In [None]:
# Show the attached GPU(s), driver / CUDA versions and current memory use before loading models.
!nvidia-smi

Fri Apr 26 06:15:30 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              41W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Authenticate with the Hugging Face Hub; required for the private push_to_hub at the end of
# training (and for any private models/datasets loaded below).
# NOTE(review): `!` shell commands may not accept interactive input outside Colab — if the
# prompt hangs, set the HF_TOKEN environment variable instead.
!huggingface-cli login

In [None]:
!pip install bitsandbytes
!pip install accelerate
!pip install transformers torch
!pip install datasets
!pip install tqdm
!pip install wandb
!pip install datasets

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
# ---- Optimisation ---------------------------------------------------------------
OPTIM = "adamw_torch"
LEARNING_RATE = 3e-5
LR_SCHEDULER_TYPE = 'cosine'
WARMUP_RATIO = 0.06                  # share of total training steps used for LR warm-up
MAX_GRAD_NORM = 1.0                  # gradient-clipping threshold
NUM_EPOCHS = 1
BATCH_SIZE = 4                       # per-device train batch size
GRADIENT_ACCUMULATION_STEPS = 1
GROUP_BY_LENGTH = False
DDP_FIND_UNUSED_PARAMETERS = False

# Precision: kept off because CPU inference is planned later.
# NOTE(review): Trainer fp16 is mixed precision and still keeps fp32 weights; bf16 is also
# available on A100 — confirm whether full fp32 training is really required.
FP16 = False

# ---- Logging / evaluation / checkpointing ---------------------------------------
REPORT_TO = 'wandb'
LOGGING_STEPS = 10
EVALUATION_STRATEGY = 'steps'
EVAL_STEPS = 400
SAVE_STRATEGY = "steps"
SAVE_STEPS = 4000
SAVE_TOTAL_LIMIT = 1
LOAD_BEST_MODEL_AT_END = False
OUTPUT_DIR = './nllb_350M'

# ---- Knowledge distillation -----------------------------------------------------
TEMPERATURE = 5                      # softmax temperature for teacher/student logits
LAMBDA_PARAM = 0.5                   # weight of the distillation term vs. the label loss

In [None]:
# Reference only (not executed): recipe for building a small student by keeping the first 3
# encoder and 3 decoder layers of NLLB-200-distilled-600M and updating the config to match.
# The current run instead loads an already-trained student checkpoint in the next cell.
# teacher_model = AutoModelForSeq2SeqLM.from_pretrained('facebook/nllb-200-3.3B')
# model = AutoModelForSeq2SeqLM.from_pretrained('facebook/nllb-200-distilled-600M')
# tokenizer = AutoTokenizer.from_pretrained('facebook/nllb-200-distilled-600M', src_lang='eng_Latn', tgt_lang='kor_Hang')

# model.model.encoder.layers = model.model.encoder.layers[:3]
# model.model.decoder.layers = model.model.decoder.layers[:3]

# # keep the config consistent with the truncated layer stacks
# model.config.encoder_layers = 3
# model.config.decoder_layers = 3

# model.num_parameters()

In [None]:
# Teacher: NLLB-200 1.3B. Student: the previous 350M distillation checkpoint (v16), trained further here.
teacher_model = AutoModelForSeq2SeqLM.from_pretrained('facebook/nllb-200-1.3B')
model = AutoModelForSeq2SeqLM.from_pretrained('dhtocks/nllb_350M_en_ko_v16')
# One tokenizer for both models — assumes all NLLB-200 checkpoints share the same vocabulary.
tokenizer = AutoTokenizer.from_pretrained('facebook/nllb-200-distilled-600M', src_lang='eng_Latn', tgt_lang='kor_Hang')

# Force decoding to start with the Korean language token. Look the id up from the tokenizer
# rather than hard-coding it, so it can never drift from the vocabulary actually in use.
KOR_HANG_TOKEN_ID = tokenizer.convert_tokens_to_ids('kor_Hang')
assert KOR_HANG_TOKEN_ID != tokenizer.unk_token_id, "'kor_Hang' is not a known token in this tokenizer"

for nllb in (teacher_model, model):
    nllb.config.forced_bos_token_id = KOR_HANG_TOKEN_ID
    # generate() reads generation_config, which is not refreshed when config is edited after
    # loading, so set it there too for inference with the trained student.
    if getattr(nllb, 'generation_config', None) is not None:
        nllb.generation_config.forced_bos_token_id = KOR_HANG_TOKEN_ID

model.num_parameters()

In [None]:
# Display the student model's configuration (layer counts, hidden sizes, forced_bos_token_id, ...).
model.config

In [None]:
# Display the teacher model's configuration for comparison with the student.
teacher_model.config

In [None]:
def data_prepare(dataset):
  """Tokenize a single record of the form {'data': {'eng_Latn': str, 'kor_Hang': str}}.

  Used with `Dataset.map` (non-batched), so `dataset` is one example, not a whole dataset.
  The English text becomes the model input and the Korean text becomes `labels`
  (via the module-level `tokenizer`'s `text_target`).
  """
  pair = dataset['data']
  source_text = pair['eng_Latn']
  target_text = pair['kor_Hang']
  return tokenizer(source_text, text_target=target_text)

In [None]:
# Required Data Format (training set, one record per sentence pair):
#
# 'train': {
#     'data': [
#         {
#             'eng_Latn': 'good morning',
#             'kor_Hang': '좋은 아침'
#         },
#         ...
#     ]
# }
train_raw = load_dataset('dhtocks/nllb_en_ko_1M_part14')['train']

# Evaluation set: FLORES dev split. The eng_Latn and kor_Hang configs are loaded separately
# and paired row by row, so they must have the same number of sentences.
flores_en = load_dataset("facebook/flores", 'eng_Latn')['dev']['sentence']
flores_ko = load_dataset("facebook/flores", 'kor_Hang')['dev']['sentence']
assert len(flores_en) == len(flores_ko), "FLORES eng_Latn/kor_Hang dev splits differ in length"

eval_raw = Dataset.from_dict({
    'data': [{'eng_Latn': en, 'kor_Hang': ko} for en, ko in zip(flores_en, flores_ko)]
})

# Tokenize directly from the loaded split instead of first copying the whole 'data' column
# into a Python list; remove_columns drops every raw column so only model inputs remain.
train_dataset = train_raw.map(data_prepare, remove_columns=train_raw.column_names)
eval_dataset = eval_raw.map(data_prepare, remove_columns=eval_raw.column_names)

In [None]:
class DistilTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that distills a frozen teacher into the student being trained.

    Per batch the loss is

        (1 - lambda_param) * student_label_loss
        + lambda_param * temperature**2 * KL(teacher_T || student_T)

    where *_T are the temperature-softened next-token distributions. The KL term is taken
    over the vocabulary axis and averaged over target positions whose label is not -100.

    Args:
        teacher_model: frozen teacher; moved to the training device and put in eval mode.
        student_model: the model to train; handed to Seq2SeqTrainer as `model`.
        temperature: softmax temperature applied to both teacher and student logits.
        lambda_param: weight of the distillation term (0 = labels only, 1 = teacher only).
        *args, **kwargs: forwarded to Seq2SeqTrainer.
    """

    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        # Put the teacher on the same device the Trainer uses for the student.
        self.teacher.to(self.args.device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    @staticmethod
    def _distillation_loss(student_logits, teacher_logits, labels, temperature):
        """Temperature-scaled KL(teacher || student), averaged over real target tokens.

        Logits are expected as (batch, tgt_len, vocab); both softmaxes are over the last
        (vocab) axis. Positions where `labels == -100` (label padding) are excluded. If
        `labels` is None, every position is counted.
        """
        log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
        log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
        # Per-position KL summed over the vocabulary -> (batch, tgt_len).
        token_kl = F.kl_div(log_p_student, log_p_teacher, reduction='none', log_target=True).sum(dim=-1)
        if labels is None:
            mask = torch.ones_like(token_kl)
        else:
            mask = labels.ne(-100).to(token_kl.dtype)
        # clamp avoids division by zero on an all-padding batch.
        mean_kl = (token_kl * mask).sum() / mask.sum().clamp(min=1.0)
        # T^2 keeps the soft-target gradient magnitude comparable across temperatures.
        return mean_kl * (temperature ** 2)

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        """Blend the student's label loss with the distillation loss.

        `student` is the model the Trainer passes in (possibly wrapped, e.g. by DDP). It is
        used instead of `self.student` so any wrapping stays in effect. `num_items_in_batch`
        is accepted for compatibility with newer Trainer versions and is not used.
        """
        student_output = student(**inputs)

        # The teacher gets the same inputs (including labels, from which the decoder inputs
        # are built) but no gradient.
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        distillation_loss = self._distillation_loss(
            student_output.logits,
            teacher_output.logits,
            inputs.get('labels'),
            self.temperature,
        )

        student_target_loss = student_output.loss

        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

In [None]:
# W&B settings read by transformers' WandbCallback.
os.environ["WANDB_PROJECT"] = "NLLB_DSTILLATION"
# WANDB_LOG_MODEL accepts "end", "checkpoint" or "false". An arbitrary name such as
# "nllb_350M" is not recognised and disables model logging. "end" uploads the final model as a
# W&B artifact; set "false" to skip the upload.
os.environ["WANDB_LOG_MODEL"] = "end"

args = Seq2SeqTrainingArguments(
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_ratio=WARMUP_RATIO,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    fp16=FP16, # False - I'm considering CPU inference later
    logging_steps=LOGGING_STEPS,
    optim=OPTIM,
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy=EVALUATION_STRATEGY,
    eval_steps=EVAL_STEPS,
    save_strategy=SAVE_STRATEGY,
    max_grad_norm=MAX_GRAD_NORM,
    save_steps=SAVE_STEPS,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    output_dir=OUTPUT_DIR,
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    save_total_limit=SAVE_TOTAL_LIMIT,
    ddp_find_unused_parameters=DDP_FIND_UNUSED_PARAMETERS,
    group_by_length=GROUP_BY_LENGTH,
    report_to=REPORT_TO
)

trainer = DistilTrainer(
    teacher_model=teacher_model,
    student_model=model,
    temperature=TEMPERATURE,
    lambda_param=LAMBDA_PARAM,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Pads labels with -100 so the label loss ignores padding.
    data_collator=DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)

# Compiling the models after the trainer is constructed has no effect on training, because the
# trainer already holds the uncompiled modules. To compile the student, pass torch_compile=True
# to Seq2SeqTrainingArguments instead.

trainer.train()

# `model` is the same object the trainer optimised, so this pushes the trained student weights.
model.push_to_hub('dhtocks/nllb_350M_en_ko_v17', private=True)