In [1]:
import csv
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration, MBartTokenizer
import random
import os

In [2]:
import torch

In [3]:
# Line-aligned parallel training corpus (one sentence per line). processdata() turns the
# "_c" file into the 'source' column and the "_m" file into the 'target' column.
# NOTE(review): the "_c" filename spells "24-historoes" while the "_m" one uses
# "24_histories" -- confirm against the files on disk before "fixing" it.
train_data_m, train_data_c = (
    'EvaHan2023_train_data/train_24_histories_m_utf8.txt',
    'EvaHan2023_train_data/train_24-historoes_c_utf8.txt',
)

In [4]:
# Use the first CUDA device when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# Pretrained Hugging Face checkpoint to fine-tune, and a local output directory
# (presumably where fine-tuned weights are saved -- not used in this part of the notebook).
model_name, output_dir = "facebook/mbart-large-cc25", './model_save/mbart-large-cc25'

In [6]:
def processdata(filename_m, filename_c):
    """Load a line-aligned parallel corpus into a two-column DataFrame.

    Args:
        filename_m: path of the UTF-8 text file that becomes the ``'target'`` column.
        filename_c: path of the UTF-8 text file that becomes the ``'source'`` column.

    Returns:
        DataFrame with columns ``'source'`` and ``'target'``, paired by line number.
        Each cell is a one-element list ``[line]`` holding the whitespace-stripped
        line; the list wrapper is kept because downstream code tokenizes each cell
        as-is.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """
    def _read_lines(path):
        # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading BOM,
        # which str.strip() would otherwise leave glued to the first sentence.
        with open(path, 'r', encoding='utf-8-sig') as fh:
            return [[line.strip()] for line in fh]

    data_m = _read_lines(filename_m)
    data_c = _read_lines(filename_c)
    if len(data_c) != len(data_m):
        # Rows are paired purely by line number, so a count mismatch means the
        # corpus is misaligned; report which files disagree instead of pandas'
        # generic length error.
        raise ValueError(
            f"{filename_c} has {len(data_c)} lines but {filename_m} has {len(data_m)}; "
            "the parallel files must be line-aligned"
        )
    return pd.DataFrame({'source': data_c, 'target': data_m})

In [7]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one source/target row per item for mBART.

    Args:
        data: DataFrame with a ``'source'`` column, plus ``'target'`` when
            ``with_labels`` is True (e.g. the frame returned by ``processdata``).
        src_lang: mBART language code for the source side.
        tgt_lang: mBART language code for the target side.
        model_name: checkpoint whose tokenizer is loaded.
        with_labels: when True each item also carries ``labels`` built from the
            target text; when False only the source is encoded.
    """

    def __init__(self, data, src_lang, tgt_lang, model_name, with_labels = True):
        # The language codes must be configured on the tokenizer itself: the
        # tokenizer's __call__ does not recognise src_lang/tgt_lang keyword
        # arguments (they fall into **kwargs and are ignored), which silently
        # encoded everything with the tokenizer's default language code.
        self.tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.with_labels = with_labels
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the tokenizer output (``return_tensors="pt"``) for row ``index``.

        ``index`` is positional (0 .. len-1), matching ``__len__``, so the frame's
        own index labels (e.g. after filtering) do not matter.
        """
        src = self.data['source'].iloc[index]
        if not self.with_labels:
            return self.tokenizer(src, return_tensors="pt")
        tgt = self.data['target'].iloc[index]
        # text_target encodes tgt in target-language mode and returns it as
        # "labels"; the legacy ``tgt_texts`` keyword is not a __call__ argument
        # and produced no labels at all.
        return self.tokenizer(src, text_target=tgt, return_tensors="pt")

In [None]:
class MyModel(nn.Module):
    """``nn.Module`` wrapper around a pretrained mBART seq2seq model and its tokenizer.

    Args:
        model_name: Hugging Face checkpoint name or path, used for both the
            tokenizer and the model weights.
        freeze_bert: if True, sets ``requires_grad = False`` on every parameter
            of the wrapped model.
    """

    def __init__(self, model_name, freeze_bert = False):
        super().__init__()
        self.tokenizer = MBartTokenizer.from_pretrained(model_name)
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        if freeze_bert:
            # NOTE(review): despite the name this freezes the whole model
            # (encoder and decoder), leaving nothing trainable -- confirm intended.
            for p in self.model.parameters():
                p.requires_grad = False

    def forward(self, input_ids, labels, attention_mask=None):
        """Return the seq2seq training loss of ``input_ids`` against ``labels``.

        ``attention_mask`` is optional (None keeps the previous behaviour); pass
        it for padded batches so padding positions are not attended to.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output.loss

    def generate(self, input_ids, labels, decoder_start_token, attention_mask=None):
        """Generate a translation and decode it alongside the reference.

        Args:
            input_ids: source token ids.
            labels: reference token ids. Tensor positions equal to -100 (the
                usual loss-ignore value) are decoded as padding.
            decoder_start_token: language code looked up in
                ``tokenizer.lang_code_to_id`` and used as the decoder start token.
            attention_mask: optional source mask forwarded to ``model.generate``.

        Returns:
            ``(generated, reference)`` strings for the FIRST example of the batch only.
        """
        generated_tokens = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=self.tokenizer.lang_code_to_id[decoder_start_token],
        )
        if torch.is_tensor(labels):
            # -100 has no vocabulary entry and cannot be decoded; map it to the pad
            # id, which skip_special_tokens=True then drops.
            labels = labels.masked_fill(labels == -100, self.tokenizer.pad_token_id)
        generated_sentences = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        ground_truth_sentences = self.tokenizer.batch_decode(labels, skip_special_tokens=True)[0]
        return generated_sentences, ground_truth_sentences