
Transformer keeps predicting the same token #82

Open
MagicianWu opened this issue May 14, 2024 · 8 comments

MagicianWu commented May 14, 2024

Hi @lucidrains and @MarcusLoppe,

I have successfully trained a MeshAutoencoder, where the validation loss is as low as the training loss. I used around 28,000 meshes from ShapeNet for training and validation, with a maximum face count of 800. The loss graphs are below.
[training and validation loss curves]
I was using the lookup-free quantizer (LFQ). I wonder if it is normal to see the fluctuation in the commit_loss, and that in the end the commit_loss stays near -1. The commit_loss_weight is set to 0.5.

Afterwards, when training the MeshTransformer without text conditioning, I ran into a problem that I could not solve after many attempts: the transformer keeps generating the same token repeatedly during inference, and I have no idea why.

[screenshot of the repeated generated tokens]
I spent nearly 4 days training on an 8-GPU server, and the training loss would not even go below 1.

Below is the config for my training of both the autoencoder and the transformer:

model:
  target: meshgpt_pytorch.meshgpt_pytorch.MeshAutoencoder
  params:
    num_discrete_coors: 128
    decoder_dims_through_depth:
    - value: 128
      repeat: 3
    - value: 192
      repeat: 4
    - value: 256
      repeat: 23
    - value: 384
      repeat: 3
    dim_codebook: 384
    codebook_size: 16384
    dim_area_embed: 16
    dim_coor_embed: 16
    dim_normal_embed: 24
    dim_angle_embed: 8
    commit_loss_weight: 0.5
    use_residual_lfq: true
training:
  batch_size: 25
  num_train_steps: 1200
  checkpoint_every: 50
  grad_accum_every: 1
  val_every: 10
  checkpoint_folder: ./checkpoints/28418_v1368_f800_b25_ag_LFQ_small
  use_wandb: true
data:
  json_path: /home/jiaqi/meshgpt-pytorch/model_paths_28418_v1368_f800.json
  augmented: true
MeshTransformer:
  params:
    dim: 1024
    coarse_pre_gateloop_depth: 6
    fine_pre_gateloop_depth: 4
    attn_depth: 24 # 12
    # attn_heads : 24 # 16
    # attn_dim_head : 128 # 64
    fine_attn_depth: 8
    fine_attn_dim_head: 64
    fine_attn_heads: 16
    dropout: 0
  training:
    batch_size: 6
    grad_accum_every: 1
    num_train_steps: 1200
    checkpoint_every: 5
    val_every: 5
    checkpoint_folder: ./checkpoints/tf_table_8919_v703_f800_bs6_ag_small24_clip
    learning_rate: 2.0e-04
    use_wandb: true
  data:
    json_path: /home/jiaqi/meshgpt-pytorch/model_paths_table_8919_v703_f800_split.json
    augmented: true
  finetune_training:
    learning_rate: 1e-02
    batch_size: 7
    grad_accum_every: 1
    num_train_steps: 200
    checkpoint_every: 20
    val_every: 5
    checkpoint_folder: ./checkpoints/tf_finetune_8919_v703_f800_bs7_ag
    use_wandb: true
  finetune_data:
    json_path: /home/jiaqi/meshgpt-pytorch/model_paths_table_8919_v703_f800.json
    augmented: true
@MagicianWu (Author)

I guess it is related to issue #80; I will update my repo and retrain on a small dataset.

MarcusLoppe (Contributor) commented May 14, 2024

Hi @lucidrains and @MarcusLoppe,

I have successfully trained a MeshAutoencoder, where the validation loss is as low as the training loss. I used around 28,000 meshes from ShapeNet for training and validation, with a maximum face count of 800. The loss graphs are below.

I was using the lookup-free quantizer (LFQ). I wonder if it is normal to see the fluctuation in the commit_loss, and that in the end the commit_loss stays near -1. The commit_loss_weight is set to 0.5.

Hi, awesome that you have the resources to be able to train the model!

  1. What reconstruction loss were you able to achieve with the auto-encoder?

About the commit loss weight: the LFQ uses a 'diversity_gamma' variable (default 1.0) which can be passed during its creation.
The purpose is to add a 'safety net' to the quantizer so the loss is not as big when it tries to explore other ways to assign the codes. It's applied as follows, which explains why you get a -1 loss.
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
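To make that concrete, here is an illustrative sketch (not the library's exact code) of the shape of that term: per-sample entropy is pushed toward 0, while the entropy of the batch-averaged code distribution is pushed toward log(K), so when the codebook is used evenly the whole term drifts toward -diversity_gamma * log(K).

import torch

def lfq_entropy_aux_loss(logits: torch.Tensor, diversity_gamma: float = 1.0) -> torch.Tensor:
    # logits: (batch, codebook_size) similarity of each input vector to each code
    probs = logits.softmax(dim = -1)
    log_probs = probs.clamp(min = 1e-9).log()
    # entropy of each sample's code distribution -> minimized (confident assignments)
    per_sample_entropy = (-probs * log_probs).sum(dim = -1).mean()
    # entropy of the batch-averaged distribution -> maximized (even codebook usage)
    avg_probs = probs.mean(dim = 0)
    codebook_entropy = (-avg_probs * avg_probs.clamp(min = 1e-9).log()).sum()
    return per_sample_entropy - diversity_gamma * codebook_entropy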

The higher the commit loss weight, the more importance is given to how well the encoder can quantize the mesh.
total_loss = recon_loss + commit_loss.sum() * self.commit_loss_weight

During my training on 13k meshes with a 2k codebook I used a very high commit loss weight, around 0.45, which I then reduced gradually (to 0.2) so the training focused more on the reconstruction loss.
I preferably want the commit loss to be under 1.5 but above 0; if it goes below 0, that tells me it's having a 'too easy' time training and more focus can be given to the decoder loss.
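A purely hypothetical sketch of that staged schedule, assuming commit_loss_weight is stored as a plain attribute on the autoencoder that can be changed between training phases (check the MeshAutoencoder source before relying on this); train_autoencoder_for_a_while is just a placeholder for one training phase:

for phase_weight in (0.45, 0.35, 0.28, 0.2):
    autoencoder.commit_loss_weight = phase_weight    # assumption: plain attribute used in forward()
    train_autoencoder_for_a_while(autoencoder)       # placeholder for one training phase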

If you experience a higher reconstruction loss with the auto-encoder, you might want to take a look at the model below, which has been the most successful during my testing.

  • Using a higher layer count for the decoder seems to help capture more detail.
  • I hyperparameter-tuned the face embedding dims; the most important of them are the area and coord embeddings, the normal is less important, and the angle seems to have very little effect.
  • The attention layers also help with its 'memory'.
  • My aim is to keep the encoder as small and simple as possible, since this seems to help it output simplistic codes for both the decoder and the transformer.
  • A dim size of 192 for the codebook seems to be enough for the decoder to learn the relationships; too high a dim count will make the decoder 'overthink', while a smaller dim helps it create a simple "world-view", which makes it easier to train.
autoencoder = MeshAutoencoder(
    decoder_dims_through_depth = (128,) * 6 + (192,) * 12 + (256,) * 24 + (384,) * 6,
    # codebook_size = 2048,  # for the 250-face dataset; a higher face count probably requires 16k
    dim_codebook = 192,
    dim_area_embed = 16,
    dim_coor_embed = 16,
    dim_normal_embed = 16,
    dim_angle_embed = 8,

    attn_decoder_depth = 4,
    attn_encoder_depth = 2
)

Afterwards, when training the MeshTransformer without text conditioning, I ran into a problem that I could not solve after many attempts: the transformer keeps generating the same token repeatedly during inference, and I have no idea why.

I spent nearly 4 days training on an 8-GPU server, and the training loss would not even go below 1.

Below is the config for my training of both the autoencoder and the transformer:

  1. That seems like quite the problem. The 28k models, how many times are they augmented?
    I prefer a high augmentation count for the autoencoder training so it can create a robust 'mapping' of the mesh and how to translate it into codes. For the transformer I pre-train with the same high augmentation count.
    Afterwards I 'fine-tune' it, lowering the augmentation count to x1 to x10 augmentations depending on how novel the generated mesh should be (see the sketch below).
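A hypothetical sketch of that workflow: materialize a fixed number of augmented copies per mesh up front (e.g. a high count for pre-training, x1 to x10 for fine-tuning) so the dataset stays static and reproducible. augment_mesh is a placeholder for whatever augmentation you use (e.g. random rotation plus rescaling, like the __getitem__ shown further down this thread).

def build_augmented_dataset(meshes, augment_mesh, copies_per_mesh = 10):
    # meshes: iterable of (vertices, faces) pairs
    augmented = []
    for vertices, faces in meshes:
        augmented.append((vertices, faces))                      # keep the original
        for _ in range(copies_per_mesh):
            augmented.append((augment_mesh(vertices), faces))    # fixed copies, not on-the-fly
    return augmented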

I have an idea about the inference issue: since the model is probabilistic, at the start it probably generates slightly wrong tokens, and this becomes a snowball effect since the sequence goes out of distribution (never seen before) and the probabilities 'go crazy'.
If you want to check how the actual training is progressing, you can use the code below, which will let you view the results without the auto-regressive sampling (i.e. with the ground-truth tokens as input).

About your transformer setup: use CLIP as the text conditioner model, since it has a high level of distance between the embeddings, and also set "text_condition_cond_drop_prob" to 0.0, which will help with the text conditioning.
You have commented out the attention heads and their dim size, which would default to 12 and 64; however, since the dim is 1024 this creates an uneven head dim (1024 / 12 ≈ 85.33). It's probably best to keep the head dim size at 64, so with a dim of 1024 that means 1024 / 64 = 16 attention heads.
The same goes for the fine decoder, which uses the same dim size as the main decoder.

text_condition_model_types = "clip",  
text_condition_cond_drop_prob = 0.0, 
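A minimal sketch of how these settings might fit together in the transformer constructor, reusing the parameter names from the snippet and config above; condition_on_text is assumed to be the flag that enables the text conditioner, so double-check it against the repo.

transformer = MeshTransformer(
    autoencoder,
    dim = 1024,
    attn_depth = 24,
    attn_dim_head = 64,
    attn_heads = 16,                        # 1024 / 64 = 16 keeps the head dim at 64
    condition_on_text = True,               # assumption: flag that enables text conditioning
    text_condition_model_types = "clip",    # CLIP gives well-separated text embeddings
    text_condition_cond_drop_prob = 0.0,
)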

About the training setup: it's recommended that you have an effective batch size of 64 (as per the paper) so the model is more generalized and has less of a 'knee-jerk' reaction when it's hit with a high loss.
You can use grad accumulation to achieve a higher effective batch size without more VRAM usage. If you have a batch size of 8 and grad accum set to 8, you will have an effective batch size of 8 * 8 = 64.
In your case, with a batch size of 6, you can set grad accum to 10 to achieve an effective batch size of 60.
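As a side note on the batch-size point, a hedged sketch of how this could be wired into the trainer, reusing the keys from the training config above (the exact constructor signature of MeshTransformerTrainer may differ from this):

trainer = MeshTransformerTrainer(
    transformer,
    dataset = dataset,
    batch_size = 6,
    grad_accum_every = 10,    # 6 * 10 = 60 effective batch size, close to the paper's 64
    learning_rate = 2e-4,
    num_train_steps = 1200,
)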

from einops import rearrange
import torch.nn.functional as F
from meshgpt_pytorch import mesh_render

folder = "/kaggle/working/"

coords = []
for item in dataset.data[:5]:
    transformer.eval()
    # Tokenize the ground-truth mesh with the frozen autoencoder
    tokenized_mesh = transformer.autoencoder.tokenize(
        vertices = item['vertices'],
        faces = item['faces'],
        face_edges = item['face_edges']
    )
    # Teacher-forced forward pass: feed the ground-truth codes and get the logits back
    logits = transformer.forward_on_codes(codes = tokenized_mesh.unsqueeze(0), texts = [item["texts"]], return_loss = False, return_cache = False)
    predicted_tokens = logits.argmax(dim = -1)[:, :-2]
    ground_truth = tokenized_mesh.flatten().unsqueeze(0)

    ce_loss = F.cross_entropy(
        rearrange(logits[:, :-2], 'b n c -> b c n'),
        ground_truth,
        ignore_index = transformer.pad_id
    )
    print("loss", ce_loss)

    # Mask out everything after the first EOS token, then trim to a multiple of num_quantizers
    codes = predicted_tokens
    is_eos_codes = (codes == transformer.eos_token_id)
    mask = is_eos_codes.float().cumsum(dim = -1) >= 1
    codes = codes.masked_fill(mask, transformer.pad_id)

    code_len = codes.shape[-1]
    round_down_code_len = code_len // transformer.num_quantizers * transformer.num_quantizers
    codes = codes[:, :round_down_code_len]

    # Decode the predicted codes back to faces and collect them for rendering
    transformer.autoencoder.eval()
    predicted_coords, _ = transformer.autoencoder.decode_from_codes_to_faces(codes)

    coords.append(predicted_coords)

mesh_render.combind_mesh(f'{folder}/text+prompt_all.obj', coords)

@MarcusLoppe (Contributor)

I guess it is related to issue #80; I will update my repo and retrain on a small dataset.

The latest updates have actually helped with that issue; I've not tested the latest one, but the previous ones managed to follow the text conditioning.

Btw, to avoid fully retraining your model during updates and whatnot, you can just load the previous model using strict = False and ignore the warnings, or if there is an error you can just delete the affected keys.
The gateloop and the other minor layers will help the new version of the transformer pick up where it left off; even if you remove the main decoder layers, the training loss will be almost where it was before.

pkg = torch.load("./transformer_2k_clip_loss_0.67.pt")
# del pkg['model']['sos_token']  # example of deleting an affected key before loading
transformer.load_state_dict(pkg['model'], strict = False)
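If you'd rather not delete keys by hand, a generic alternative (a sketch, not part of the repo) is to drop every checkpoint tensor whose name or shape no longer matches the current model before loading:

model_state = transformer.state_dict()
# Keep only checkpoint entries that still exist in the model with the same shape
filtered = {k: v for k, v in pkg['model'].items()
            if k in model_state and model_state[k].shape == v.shape}
transformer.load_state_dict(filtered, strict = False)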

@MagicianWu (Author)

What reconstruction loss were you able to achieve with the auto-encoder?

Both training and validation loss are reduced to around 0.34.

The purpose is to add a 'safety net' to the quantizer so the loss is not as big when it tries to explore other ways to assign the codes. It's applied as follows, which explains why you get a -1 loss.
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy

This refers to the formula used in the paper:
[screenshot of the formula from the paper]

I dived into the code. entropy_aux_loss is a regularization term which encourages less uncertainty in the code distribution of each embedding and even usage of the codebook. Therefore per_sample_entropy is minimized to close to zero, and codebook_entropy is maximized towards log K (the entropy of an even distribution), which is log(codebook_size) = log(2**14) = 14 (base 2) in my case. Therefore, I guess entropy_aux_loss will be minimized to around -10 ~ -14, because codebook_entropy is maximized to around 10 ~ 14.

aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
After calculating entropy_aux_loss, each layer of quantization returns aux_loss, which combines the entropy_aux_loss and commit_loss (which is an MSE loss between the quantized vector and the original vector). self.entropy_loss_weight defaults to 0.1, and self.commitment_loss_weight is set to 0.5 in my case. Therefore, entropy_aux_loss will contribute about -1 ~ -1.4 to the aux_loss, together with some positive term from the commit_loss.
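A small back-of-the-envelope calculation of that reasoning; only the weights and the -14 entropy bound follow from the discussion above, while the commitment MSE value is a made-up placeholder:

entropy_loss_weight = 0.1        # library default
commitment_loss_weight = 0.5     # value used in this run
entropy_aux_loss = -14.0         # ~ -log2(16384) when codebook usage is fully even
commit_mse = 0.5                 # hypothetical MSE between quantized and original vectors
aux_loss = entropy_aux_loss * entropy_loss_weight + commit_mse * commitment_loss_weight
print(aux_loss)                  # -1.4 + 0.25 = -1.15, i.e. close to the observed -1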

This is my rough understanding of why the commit_loss ends up around -1. It may be wrong, so please correct me if that is the case!

I preferably want the commit loss to be under 1.5 but above 0; if it goes below 0, that tells me it's having a 'too easy' time training and more focus can be given to the decoder loss.

From my discussion above, my understanding is that a commit loss below 0 indicates that the codebook has good utilization?

Using a higher layer count for the decoder seems to help capture more detail.

I agree with this. I tried increasing the number of layers in the encoder before, but the performance deteriorated. I did not use the attention layers in either the decoder or the encoder. I remember I tried, but the convergence became slower? I am not sure.

That seems like quite the problem. The 28k models, how many times are they augmented?
I prefer a high augmentation count for the autoencoder training so it can create a robust 'mapping' of the mesh and how to translate it into codes. For the transformer I pre-train with the same high augmentation count.
Afterwards I 'fine-tune' it, lowering the augmentation count to x1 to x10 augmentations depending on how novel the generated mesh should be.

For the dataset, I create it on the fly. I mean, I did not read all the vertices and faces beforehand, create the face edges, and store this information on disk; I think that is inflexible when I want to change the dataset, and it occupies a lot of storage space, though I am not sure about this. The disadvantage of my approach is the high VRAM consumption, I guess. The data augmentation is also done during training. Worth mentioning that the augmented mesh is not fixed, and each data instance has a probability of 0.3 of being augmented. I have not tested this hyperparameter, and I do not know whether this approach is feasible or not. Let me know your opinion!

The __getitem__ function in my dataset class:

def __getitem__(self, index):
    file_path = self.data_paths[index]
    vertices, faces = self.get_mesh(file_path)

    # With probability 0.3, augment the mesh on the fly: center, normalize, rotate, re-normalize
    if self.augmentation_enabled and np.random.rand() <= 0.3:
        vertices = vertices.numpy()
        vertices = self.center_vertices(vertices)
        vertices = self.normalize_to_unit_scale(vertices)
        vertices = self.random_rotation(vertices)
        # vertices = self.random_shift(vertices)
        vertices = self.normalize_to_unit_scale(vertices)
        vertices = torch.from_numpy(vertices)

    return vertices, faces

@MagicianWu (Author)

You have commented out the attention heads and their dim size, which would default to 12 and 64

I just checked, and the default value for the attention heads (attn_heads) is 16. Still, grateful to learn that it is best to keep the head dim size at 64.

@MarcusLoppe (Contributor)

What reconstruction loss were you able to achieve with the auto-encoder?

Both training and validation loss are reduced to around 0.34.

The purpose is to add a 'safety net' to the quantizer so the loss is not as big when it tries to explore other ways to assign the codes. It's applied as follows, which explains why you get a -1 loss.
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy

This refers to the formula used in the paper.

I dived into the code. entropy_aux_loss is a regularization term which encourages less uncertainty in the code distribution of each embedding and even usage of the codebook. Therefore per_sample_entropy is minimized to close to zero, and codebook_entropy is maximized towards log K (the entropy of an even distribution), which is log(codebook_size) = log(2**14) = 14 (base 2) in my case. Therefore, I guess entropy_aux_loss will be minimized to around -10 ~ -14, because codebook_entropy is maximized to around 10 ~ 14.

aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
After calculating entropy_aux_loss, each layer of quantization returns aux_loss, which combines the entropy_aux_loss and commit_loss (which is an MSE loss between the quantized vector and the original vector). self.entropy_loss_weight defaults to 0.1, and self.commitment_loss_weight is set to 0.5 in my case. Therefore, entropy_aux_loss will contribute about -1 ~ -1.4 to the aux_loss, together with some positive term from the commit_loss.

This is my rough understanding of why the commit_loss ends up around -1. It may be wrong, so please correct me if that is the case!

I'm not that great at algebra, so I haven't really dived into this; I've mostly just glanced over the code and done some light debugging.
I think lucidrains can probably give you a better answer than me :)

I preferably want the commit loss to be under 1.5 but above 0; if it goes below 0, that tells me it's having a 'too easy' time training and more focus can be given to the decoder loss.

From my discussion above, my understanding is that a commit loss below 0 indicates that the codebook has good utilization?

That is my understanding as well.

Using a higher layer count for the decoder seems to help capture more detail.

I agree with this. I tried increasing the number of layers in the encoder before, but the performance deteriorated. I did not use the attention layers in either the decoder or the encoder. I remember I tried, but the convergence became slower? I am not sure.

I found that when I moved from the standard ResNet setup to a longer chain of ResNet blocks (6, 12, 24, 6) it really helped the training capture more detail, with only a small increase in parameters.
The attention layers will increase the training time; however, the benefit outweighs the cost, and the model can reach a lower loss than without them.

The encoder should be as simple as possible; the performance has always been worse every time I've increased the graph encoder's dims or layers.

That seems like quite the problem. The 28k models, how many times are they augmented?
I prefer a high augmentation count for the autoencoder training so it can create a robust 'mapping' of the mesh and how to translate it into codes. For the transformer I pre-train with the same high augmentation count.
Afterwards I 'fine-tune' it, lowering the augmentation count to x1 to x10 augmentations depending on how novel the generated mesh should be.

For the dataset, I create it on the fly. I mean, I did not read all the vertices and faces beforehand, create the face edges, and store this information on disk; I think that is inflexible when I want to change the dataset, and it occupies a lot of storage space, though I am not sure about this. The disadvantage of my approach is the high VRAM consumption, I guess. The data augmentation is also done during training. Worth mentioning that the augmented mesh is not fixed, and each data instance has a probability of 0.3 of being augmented. I have not tested this hyperparameter, and I do not know whether this approach is feasible or not. Let me know your opinion!

The __getitem__ function in my dataset class:

def __getitem__(self, index):
    file_path = self.data_paths[index]
    vertices, faces = self.get_mesh(file_path)

    # With probability 0.3, augment the mesh on the fly: center, normalize, rotate, re-normalize
    if self.augmentation_enabled and np.random.rand() <= 0.3:
        vertices = vertices.numpy()
        vertices = self.center_vertices(vertices)
        vertices = self.normalize_to_unit_scale(vertices)
        vertices = self.random_rotation(vertices)
        # vertices = self.random_shift(vertices)
        vertices = self.normalize_to_unit_scale(vertices)
        vertices = torch.from_numpy(vertices)

    return vertices, faces

The space required for 218k models (max 250 faces) is around 213 MB (incl. face edges) when saving with np.savez_compressed, so I wouldn't worry about disk requirements. The only things I generate live, right before training, are the codes and the text embeddings, since they might consume a lot of disk space; I've not checked how much, but generating those is pretty quick.
Take a look at mesh_dataset for how to effectively tokenize and save the dataset.

One thing to mention is that you shouldn't rescale the vertices so they reach exactly -1 or 1; I keep the scale so they are within -0.95 to 0.95.
I've had some issues in the past where vertices that sit at 1.0 will not be properly undiscretized.

The performance will be worse, since you'll need to create the face edges as well as tokenize the mesh on the fly.
Reading the mesh and augmenting it would also add high disk and CPU usage.
I strongly recommend that you keep the data static and preprocessed; if you use a batch size of 32 it will consume 96 GB of VRAM just for the face edges.
The main reason I first created the MeshDataset class was reproducibility when testing the model and its different hyper-parameters.
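A rough sketch of that kind of offline preprocessing, combining the np.savez_compressed and -0.95/0.95 scaling points above. load_mesh is a hypothetical helper, and the import path and exact signature of derive_face_edges_from_faces should be checked against the repo before use.

import numpy as np
from meshgpt_pytorch.data import derive_face_edges_from_faces   # import path is an assumption

def preprocess_dataset(mesh_paths, load_mesh, out_file = "dataset.npz"):
    # load_mesh: hypothetical loader returning (vertices: FloatTensor [V, 3], faces: LongTensor [F, 3])
    records = []
    for path in mesh_paths:
        vertices, faces = load_mesh(path)
        # keep coordinates inside roughly [-0.95, 0.95] (assumes vertices are already centered)
        vertices = vertices / vertices.abs().max() * 0.95
        face_edges = derive_face_edges_from_faces(faces)
        records.append(dict(
            vertices = vertices.numpy(),
            faces = faces.numpy(),
            face_edges = face_edges.numpy(),
        ))
    np.savez_compressed(out_file, records = np.array(records, dtype = object))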

I like the idea of the dynamic dataset, since it would produce a very robust model.
However, when I thought a bit more about it, wouldn't the autoencoder and transformer need to learn almost all possible combinations within the 3D space?
I'm not 100% sure about the math, but I imagine it would be a massive number of combinations.
That is probably why the loss has stagnated.

You have commented out the attention heads and their dim size, which would default to 12 and 64

I just checked, and the default value for the attention heads (attn_heads) is 16. Still, grateful to learn that it is best to keep the head dim size at 64.

Oh alright, I must have confused it with something else :)

@MarcusLoppe (Contributor)

@MagicianWu

Hey, the issue has been resolved, so the model can generate output following the text guidance. Here is the model I've published: https://huggingface.co/MarcusLoren/MeshGPT-preview

dainel40911 commented Jun 17, 2024

@MarcusLoppe
Sorry, I have a few questions. Could I add you on Discord for further advice?
I've already sent the request to you.
