In [1]:
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AdamW

from dataset import TranslationDataset
import warnings
import gc

# Silence all Python warnings for the rest of the session (this also hides
# library deprecation warnings; re-enable them when debugging).
warnings.filterwarnings(action='ignore')
# Run a garbage-collection pass. As the cell's last expression, its return
# value (the number of unreachable objects found) is what gets displayed.
gc.collect()

0

In [2]:
# Return unoccupied memory held by PyTorch's CUDA caching allocator; memory
# still referenced by live tensors is not affected.
torch.cuda.empty_cache()

In [3]:
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Checkpoint location shared by the teacher, the student and the tokenizer below.
model_path = 'voiceprint/m2m100_418M'

In [4]:
# Load the full pretrained M2M100 checkpoint that serves as the distillation teacher.
teacher_model = M2M100ForConditionalGeneration.from_pretrained(model_path)
# Module.to moves the parameters in place and returns the module, so this last
# expression also makes the notebook display the architecture below.
teacher_model.to(device)

M2M100ForConditionalGeneration(
  (model): M2M100Model(
    (shared): Embedding(128112, 1024, padding_idx=1)
    (encoder): M2M100Encoder(
      (embed_tokens): Embedding(128112, 1024, padding_idx=1)
      (embed_positions): M2M100SinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0): M2M100EncoderLayer(
          (self_attn): M2M100Attention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=Tr

In [5]:
# Build the student from the same checkpoint, but with only 6 encoder and 6
# decoder layers. The config override shrinks the architecture, and
# from_pretrained then loads only the weights whose names still exist, so the
# student starts from the checkpoint's layers 0-5. The warning below lists the
# unused layer 6-11 weights.
student_model = M2M100ForConditionalGeneration.from_pretrained(model_path, encoder_layers=6, decoder_layers=6)
student_model.to(device)

Some weights of the model checkpoint at voiceprint/m2m100_418M were not used when initializing M2M100ForConditionalGeneration: ['model.decoder.layers.10.self_attn.q_proj.bias', 'model.encoder.layers.11.fc1.bias', 'model.encoder.layers.10.self_attn.k_proj.weight', 'model.decoder.layers.7.encoder_attn.k_proj.bias', 'model.decoder.layers.6.fc1.bias', 'model.decoder.layers.11.encoder_attn.q_proj.bias', 'model.encoder.layers.8.self_attn.q_proj.bias', 'model.decoder.layers.10.final_layer_norm.weight', 'model.decoder.layers.11.self_attn.out_proj.bias', 'model.encoder.layers.7.final_layer_norm.weight', 'model.decoder.layers.10.encoder_attn_layer_norm.weight', 'model.encoder.layers.11.self_attn.k_proj.bias', 'model.decoder.layers.7.self_attn.v_proj.bias', 'model.encoder.layers.6.self_attn.v_proj.bias', 'model.decoder.layers.8.self_attn_layer_norm.bias', 'model.decoder.layers.11.self_attn.q_proj.bias', 'model.decoder.layers.11.final_layer_norm.bias', 'model.encoder.layers.7.self_attn_layer_norm.

M2M100ForConditionalGeneration(
  (model): M2M100Model(
    (shared): Embedding(128112, 1024, padding_idx=1)
    (encoder): M2M100Encoder(
      (embed_tokens): Embedding(128112, 1024, padding_idx=1)
      (embed_positions): M2M100SinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0): M2M100EncoderLayer(
          (self_attn): M2M100Attention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=Tr

In [6]:
# Tokenizer matching the checkpoint. evaluate() calls tokenizer.get_lang_id(),
# so this presumably resolves to the M2M100 tokenizer — verify if model_path changes.
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [7]:
def pad(data, pad_id, max_len):
    """Right-pad each 1-D tensor in ``data`` with ``pad_id`` up to ``max_len``.

    Args:
        data: iterable of 1-D tensors.
        pad_id: value used to fill the tail of every shorter tensor.
        max_len: target length. Tensors already at least this long are
            returned unchanged (never truncated).

    Returns:
        A list of tensors, each keeping the dtype and device of its input.
    """
    padded_data = []
    for seq in data:
        n_missing = max(0, max_len - len(seq))
        # new_full keeps the input's dtype/device. The previous torch.tensor([])
        # filler was float32 and silently promoted integer sequences that needed
        # no padding to float.
        filler = seq.new_full((n_missing,), pad_id)
        padded_data.append(torch.cat([seq, filler]))
    return padded_data

def collate_fn_translation(batch):
    """Collate tokenized translation examples into right-padded batch tensors.

    Each item of ``batch`` must provide ``input_ids``, ``attention_mask`` and
    ``labels`` (lists or 1-D array-likes, possibly of different lengths). Each
    field is padded independently to the longest sequence of that field:

    * ``input_ids``      -> int64,   padded with 1 (the embedding's padding_idx)
    * ``attention_mask`` -> float32, padded with 0 (padding is not attended to)
    * ``labels``         -> int64,   padded with -100 (the cross-entropy ignore index)

    Returns:
        dict mapping the three field names to tensors of shape (batch, max_len).
    """
    from torch.nn.utils.rnn import pad_sequence

    def _pad_field(key, dtype, pad_value):
        # Convert every example's field, then right-pad and stack to (batch, max_len).
        seqs = [torch.as_tensor(item[key], dtype=dtype) for item in batch]
        return pad_sequence(seqs, batch_first=True, padding_value=pad_value)

    return {
        'input_ids': _pad_field('input_ids', torch.long, 1),
        'attention_mask': _pad_field('attention_mask', torch.float32, 0),
        'labels': _pad_field('labels', torch.long, -100),
    }

In [8]:
from torchtext.data.metrics import bleu_score

def evaluate(model, tokenizer, device, file_path='data/sns100.csv', tgt_lang='en', batch_size=1):
    """Translate an evaluation set greedily and return the mean sentence-level BLEU-1.

    Every example is decoded with ``num_beams=1`` and the target-language BOS
    token forced. It is then scored against its reference with torchtext's
    ``bleu_score`` using unigrams only (``max_n=1``) over whitespace tokens.

    Args:
        model: seq2seq model exposing ``generate``; it is put in eval mode.
        tokenizer: tokenizer providing ``get_lang_id``, ``batch_decode`` and ``pad_token_id``.
        device: device the batches are moved to.
        file_path: evaluation data file passed to ``TranslationDataset``.
        tgt_lang: language code whose id is forced as the first generated token.
        batch_size: evaluation batch size. Every example in a batch is scored.

    Returns:
        Mean of the per-sentence scores (``np.mean``; NaN if the set is empty).
    """
    dataset = TranslationDataset(file_path=file_path, tokenizer=tokenizer)
    dataloader = DataLoader(dataset,
                            shuffle=False,
                            batch_size=batch_size,
                            collate_fn=collate_fn_translation)

    forced_bos = tokenizer.get_lang_id(tgt_lang)
    bleu_scores = []
    model.eval()
    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}

        # Labels must not reach generate(). -100 is collate padding (the loss
        # ignore index), not a real token id, so map it to pad before decoding.
        labels = batch.pop('labels')
        labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
        ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

        generated_tokens = model.generate(**batch, forced_bos_token_id=forced_bos, num_beams=1)
        hyp_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        # Score every sentence in the batch, not just the first one.
        for hyp, ref in zip(hyp_texts, ref_texts):
            score = bleu_score([hyp.split(' ')], [[ref.split(' ')]], max_n=1, weights=[1])
            bleu_scores.append(score)

    return np.mean(bleu_scores)


def train_teacher(num_epochs, device, teacher_model, tokenizer,
                  train_file='data/ko2en_travel_1_training_Bleu_Grouge.json',
                  batch_size=2, lr=5e-5):
    """Fine-tune ``teacher_model`` with its own seq2seq loss, evaluating after each epoch.

    Args:
        num_epochs: number of passes over the training file.
        device: device the batches are moved to (the model is expected to be there already).
        teacher_model: model to train in place; its forward must return ``.loss`` when given labels.
        tokenizer: tokenizer passed to ``TranslationDataset`` and ``evaluate``.
        train_file: training data file passed to ``TranslationDataset`` (mode='train').
        batch_size: training batch size.
        lr: AdamW learning rate.

    Returns:
        The same ``teacher_model`` instance, trained in place.
    """
    # The optimizer and forward pass previously referenced an undefined global
    # ``model`` (NameError); both must use the teacher being trained.
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are set to that class's former defaults.
    optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

    train_dataset = TranslationDataset(file_path=train_file,
                                       mode='train',
                                       tokenizer=tokenizer)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn_translation)

    for epoch in range(num_epochs):
        teacher_model.train()
        losses = []

        pbar = tqdm(train_dataloader, total=len(train_dataloader), position=0, leave=True, desc=f'Epoch {epoch}')
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = teacher_model(**batch).loss
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # Record the loss so the epoch average reflects training (it was never appended before).
            losses.append(loss.item())
            pbar.set_postfix(loss=f'{losses[-1]:.4f}')

        avg_loss = sum(losses) / len(losses) if losses else 0.0
        bleu = evaluate(teacher_model, tokenizer, device)
        print(f'Loss:{avg_loss:.2f}\tBleu:{bleu:.2f}')

    return teacher_model

In [9]:
# Optional (disabled): fine-tune the teacher before distillation.
# teacher_model = train_teacher(num_epochs=1, device=device, teacher_model=teacher_model, tokenizer=tokenizer)

In [10]:
import torch.nn.functional as F

def _distillation_loss(student_logits, teacher_logits, labels, divergence_loss_fn, temp):
    """Temperature-softened divergence between student and teacher token distributions.

    Logits are flattened to (num_tokens, vocab). Positions whose label is -100
    (label padding added by the collate function) are dropped, so padding does
    not contribute to the soft loss.
    """
    vocab_size = teacher_logits.shape[-1]
    student_flat = student_logits.reshape(-1, vocab_size)
    teacher_flat = teacher_logits.reshape(-1, vocab_size)
    keep = labels.reshape(-1) != -100
    return divergence_loss_fn(
        F.log_softmax(student_flat[keep] / temp, dim=1),
        F.softmax(teacher_flat[keep] / temp, dim=1),
    )


def train_step(teacher_model, student_model, optimizer, divergence_loss_fn, temp, alpha, epoch, device):
    """Run one distillation epoch of ``student_model`` against a frozen ``teacher_model``.

    Per batch, the optimized objective is
    ``alpha * student_loss + (1 - alpha) * divergence_loss_fn(log_softmax(student / temp), softmax(teacher / temp))``,
    where ``student_loss`` is the student's own ``outputs.loss`` on the labels.

    NOTE(review): the dataset is built from the module-level ``tokenizer`` (not a
    parameter) and re-read from ``data/sns100.csv`` on every call.
    NOTE(review): the divergence term is not multiplied by ``temp ** 2`` (as in
    Hinton et al.), so a large ``temp`` strongly down-weights its gradient —
    confirm this is intended.

    Args:
        teacher_model: model run under ``torch.no_grad()`` in eval mode.
        student_model: model being trained (put in train mode).
        optimizer: optimizer over the student's parameters.
        divergence_loss_fn: callable(log_probs, probs), e.g. ``nn.KLDivLoss(reduction='batchmean')``.
        temp: softmax temperature for the distillation term.
        alpha: weight of the student's own loss; ``1 - alpha`` weights the distillation term.
        epoch: unused; kept for call compatibility.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, teacher_avg_ppl, student_avg_ppl)`` as Python floats. The
        perplexities are the mean over batches of ``exp(batch_loss)``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    dataset = TranslationDataset(file_path='data/sns100.csv',
                                 tokenizer=tokenizer)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=1,
                            shuffle=True,
                            collate_fn=collate_fn_translation)

    losses = []
    teacher_ppls = []
    student_ppls = []

    # Nothing inside the loop changes the mode flags, so set them once.
    teacher_model.eval()
    student_model.train()

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            teacher_outputs = teacher_model(**batch)
        teacher_ppls.append(torch.exp(teacher_outputs.loss).item())

        student_outputs = student_model(**batch)
        student_loss = student_outputs.loss
        # Store a float: keeping the tensor would retain its autograd graph for the whole epoch.
        student_ppls.append(torch.exp(student_loss.detach()).item())

        distillation_loss = _distillation_loss(student_outputs.logits, teacher_outputs.logits,
                                               batch['labels'], divergence_loss_fn, temp)
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not losses:
        raise ValueError('train_step: the distillation dataloader produced no batches')

    n_batches = len(losses)
    return sum(losses) / n_batches, sum(teacher_ppls) / n_batches, sum(student_ppls) / n_batches

In [11]:
from torch import nn

def main(teacher_model, student_model, tokenizer, device, temp=7, alpha=0.3, epochs=10, lr=1e-3):
    """Distill ``teacher_model`` into ``student_model``, printing loss, perplexities and BLEU per epoch.

    Args:
        teacher_model: frozen teacher passed to ``train_step``.
        student_model: student trained in place with Adam.
        tokenizer: tokenizer used by ``evaluate``.
        device: device the batches are moved to.
        temp: softmax temperature for the distillation term.
        alpha: weight of the student's own loss; ``1 - alpha`` weights the distillation term.
        epochs: number of distillation epochs.
        lr: Adam learning rate for the student. NOTE(review): 1e-3 is well above
            common transformer fine-tuning rates (around 1e-5 to 5e-5) — consider lowering.
    """
    divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss, teacher_ppl, student_ppl = train_step(teacher_model, student_model, optimizer,
                                                    divergence_loss_fn, temp, alpha, epoch, device)
        # Named `bleu` rather than `bleu_score` so it does not shadow torchtext's bleu_score.
        bleu = evaluate(student_model, tokenizer, device)
        print(f'Loss : {loss:.2f}\tTeacher_ppl : {teacher_ppl:.2f}\tStudent_ppl : {student_ppl:.2f}\tBleu_score : {bleu:.2f}')

In [12]:
# Distill with softmax temperature 7; alpha=0.3 weights the student's own loss (0.7 goes to the KL term).
main(teacher_model=teacher_model, student_model=student_model, tokenizer=tokenizer, device=device, temp=7, alpha=0.3)

Evaluating: 100%|██████████| 100/100 [00:12<00:00,  7.70it/s]


Loss : 2.27	Teacher_ppl : 12.13	Student_ppl : 6605.48	Bleu_score : 0.03


Evaluating: 100%|██████████| 100/100 [00:10<00:00,  9.38it/s]


Loss : 1.65	Teacher_ppl : 12.13	Student_ppl : 324.76	Bleu_score : 0.01


Evaluating: 100%|██████████| 100/100 [01:16<00:00,  1.31it/s]


Loss : 1.46	Teacher_ppl : 12.13	Student_ppl : 185.73	Bleu_score : 0.01


Evaluating: 100%|██████████| 100/100 [00:45<00:00,  2.19it/s]


Loss : 1.35	Teacher_ppl : 12.13	Student_ppl : 194.93	Bleu_score : 0.02


Evaluating: 100%|██████████| 100/100 [00:56<00:00,  1.77it/s]


Loss : 1.23	Teacher_ppl : 12.13	Student_ppl : 107.02	Bleu_score : 0.02


Evaluating: 100%|██████████| 100/100 [01:40<00:00,  1.00s/it]


Loss : 1.15	Teacher_ppl : 12.13	Student_ppl : 73.46	Bleu_score : 0.01


Evaluating: 100%|██████████| 100/100 [00:03<00:00, 29.82it/s]


Loss : 1.05	Teacher_ppl : 12.13	Student_ppl : 288.92	Bleu_score : 0.02


Evaluating: 100%|██████████| 100/100 [01:46<00:00,  1.07s/it]


Loss : 0.93	Teacher_ppl : 12.13	Student_ppl : 39.55	Bleu_score : 0.00


Evaluating: 100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Loss : 0.87	Teacher_ppl : 12.13	Student_ppl : 25.27	Bleu_score : 0.03


Evaluating: 100%|██████████| 100/100 [00:06<00:00, 14.67it/s]

Loss : 0.88	Teacher_ppl : 12.13	Student_ppl : 51.13	Bleu_score : 0.04



