In [1]:
import os
import sys
sys.path.append(os.getcwd() + '/..')
import torch
import pandas as pd
import numpy as np
import re
from datetime import datetime
from info_nce import InfoNCE
from transformers import AutoModel, AutoTokenizer, get_scheduler
from tqdm.notebook import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from simcse import SimCSE

# Locations of the tab-separated training and evaluation sets.
train_path = './../../data/ret/train_cover4_neg10.tsv'
eval_path = './../../data/ret/test_cover4_neg10.tsv'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Training hyperparameters.
NUM_EPOCHS = 6
# NOTE(review): this value is handed to AdamW as the learning rate; 0.05 looks
# very high for fine-tuning a pretrained transformer -- confirm it is intended.
LR = 0.05
BATCH_SIZE = 1
loss_temperature = 0.1

# Patterns that pull quoted items out of list-like strings.
reSplitTokenset = re.compile(r'\'((?:\w+, )*\w+)\'(?:,|$)')
reSplitLabel = re.compile(r'[\'\"](.*?)[\'\"](?:,|$)')


def _split_labels(cell):
    """Strip the first and last character of ``cell`` and return every quoted label in the rest."""
    return reSplitLabel.findall(cell[1:-1])


train_data = pd.read_csv(train_path, sep='\t')
eval_data = pd.read_csv(eval_path, sep='\t')

# Turn the string-encoded 'negatives_label' column into a list of label strings per row.
train_data['negatives_label'] = train_data['negatives_label'].apply(_split_labels)
eval_data['negatives_label'] = eval_data['negatives_label'].apply(_split_labels)

# Fix the NumPy and torch random number generators.
np.random.seed(114514)
torch.manual_seed(114514)

<torch._C.Generator at 0x7fcc01699a70>

In [3]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

def load_model(checkpoint):
    """Load a pretrained model and its tokenizer from ``checkpoint``.

    The tokenizer is created with ``model_max_length=40``; the model is moved
    to the module-level ``device``. The trainable parameter count is printed.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint, model_max_length=40)
    net = AutoModel.from_pretrained(checkpoint)
    net = net.to(device)
    n_trainable = count_parameters(net)
    print(f'The model has {n_trainable} trainable parameters')
    return net, tok

In [5]:
# Checkpoint directory of the untuned model to fine-tune.
model_path = '/data2T/jingchuan/untuned/flan-t5-base/'
model, tokenizer = load_model(model_path)
# Use the last path component as the model name. normpath drops a trailing
# separator first, so the result is the same with or without the final '/'
# (splitting on '/' and taking [-2] silently returned the parent directory
# when the trailing slash was missing).
model_name = os.path.basename(os.path.normpath(model_path))
# Class name of the loaded model (e.g. 'BertModel'), read directly from the
# type instead of slicing its repr string.
model_type = type(model).__name__

The model has 222903552 trainable parameters


In [2]:
def tokenize(examples):
    """Tokenize the query, positive and negative labels of one example together.

    The texts are ordered query first, positive second, negatives after that,
    and are padded to a common length without truncation.
    """
    texts = [examples['query_label'], examples['positive_label'], *examples['negatives_label']]
    return tokenizer(texts, padding=True, truncation=False)


In [4]:
def _embed_batch(model, batch):
    """Encode one collated batch and return embeddings shaped (batch, n_texts, dim).

    ``batch`` maps model input names to tensors whose leading dimension is the
    batch size and whose trailing dimension is the sequence length; everything
    in between is flattened so the model sees one row per text.

    Models that expose a pooler use ``pooler_output``. Models without one
    (for example encoder-decoder models, which also cannot be called without
    decoder inputs) are embedded by running only the encoder and averaging its
    last hidden states over the non-padding positions.
    """
    batch_size = batch['input_ids'].shape[0]
    seq_len = batch['input_ids'].shape[-1]
    # Explicit reshape to (batch * n_texts, seq_len); an argument-less
    # squeeze() would also drop any other dimension that happens to be 1.
    flat = {name: tensor.reshape(-1, seq_len) for name, tensor in batch.items()}

    if getattr(model, 'pooler', None) is not None:
        pooled = model(**flat).pooler_output
    else:
        config = getattr(model, 'config', None)
        if getattr(config, 'is_encoder_decoder', False):
            encoder = model.get_encoder()
        else:
            encoder = model
        hidden = encoder(**flat).last_hidden_state
        # Masked mean: padding positions contribute nothing to the average.
        mask = flat['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    return pooled.reshape(batch_size, -1, pooled.shape[-1])


def _batch_loss(model, batch, lossfunc):
    """Compute the loss for one batch: text 0 is the query, text 1 the positive, the rest negatives."""
    embeds = _embed_batch(model, batch)
    return lossfunc(embeds[:, 0, :], embeds[:, 1, :], embeds[:, 2:, :])


def train_and_eval(model, train_iterator, eval_iterator, lossfunc, num_epochs, optimizer, lr_scheduler):
    """Train ``model`` for ``num_epochs`` epochs, evaluating after each one.

    Args:
        model: Model to optimise; see ``_embed_batch`` for how embeddings are obtained.
        train_iterator: Sized iterable of training batches (dicts of tensors).
        eval_iterator: Sized iterable of evaluation batches (dicts of tensors).
        lossfunc: Callable taking (query, positive, negatives) embeddings.
        num_epochs: Number of passes over ``train_iterator``.
        optimizer: Optimizer stepped once per training batch.
        lr_scheduler: Scheduler stepped once per training batch.

    Returns:
        A DataFrame indexed by epoch (starting at 1) with the mean training
        and evaluation loss per batch. The table is also displayed after
        every epoch.
    """
    loss_history = pd.DataFrame({'Epoch':[],'Training loss':[],'Evaluation loss':[]}).set_index('Epoch')
    n_train = len(train_iterator)
    n_eval = len(eval_iterator)
    with tqdm(total = num_epochs, position = 0, desc = 'Epoch') as outer:
        with tqdm(total = n_train, position = 1, leave = False, desc = 'Training') as inner1:
            with tqdm(total = n_eval, position = 2, leave = False, desc = 'Evaluating') as inner2:
                for e in range(num_epochs):
                    # ---- training pass ----
                    l_train = 0
                    model.train()
                    for batch in train_iterator:
                        loss = _batch_loss(model, batch, lossfunc)
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        l_train += loss.item()
                        inner1.update(1)

                    # ---- evaluation pass (no gradients needed) ----
                    l_eval = 0
                    model.eval()
                    with torch.no_grad():
                        for batch in eval_iterator:
                            l_eval += _batch_loss(model, batch, lossfunc).item()
                            inner2.update(1)

                    # Rewind the per-epoch bars so the next epoch starts from zero.
                    inner1.reset()
                    inner2.reset()
                    # Mean loss per batch; NaN instead of a ZeroDivisionError for an empty iterator.
                    loss_history.loc[e+1] = {
                        'Training loss': l_train / n_train if n_train else float('nan'),
                        'Evaluation loss': l_eval / n_eval if n_eval else float('nan'),
                    }
                    display(loss_history)
                    outer.update(1)
    return loss_history

In [6]:
# Tokenize every row of both splits with the notebook's tokenize() function.
train_dataset = Dataset.from_pandas(train_data).map(tokenize)
eval_dataset = Dataset.from_pandas(eval_data).map(tokenize)

# Columns exposed as torch tensors: 'BertModel' additionally gets token_type_ids.
if model_type in ['BertModel']:
    tensor_columns = ["input_ids", "token_type_ids", "attention_mask"]
else:
    tensor_columns = ["input_ids", "attention_mask"]

# Apply the same tensor format (and target device) to both splits.
for split in (train_dataset, eval_dataset):
    split.set_format(type="torch", columns=tensor_columns, device=device)

Map:   0%|          | 0/38562 [00:00<?, ? examples/s]

Map:   0%|          | 0/2032 [00:00<?, ? examples/s]

In [7]:
# NOTE(review): the training loader is not shuffled -- confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = AdamW(model.parameters(), lr=LR)

# Linear schedule with 1000 warm-up steps, spanning one step per training batch per epoch.
total_training_steps = NUM_EPOCHS * len(train_loader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=total_training_steps,
)

# InfoNCE with per-query ("paired") negatives, averaged over the batch.
lossfunc = InfoNCE(temperature=loss_temperature, reduction='mean', negative_mode='paired')

In [None]:
# Run the full training/evaluation loop and keep the per-epoch loss table.
loss_history = train_and_eval(model, train_loader, eval_loader, lossfunc, NUM_EPOCHS, optimizer, lr_scheduler)

# Timestamp the output directory so repeated runs do not overwrite each other.
now = datetime.now()
timestr = now.strftime('%Y%m%d-%H%M')

# Save the model and its tokenizer into the SAME directory so the checkpoint
# can be reloaded from a single path. (The tokenizer was previously written to
# f'/{model_name}_{timestr}', i.e. directly under the filesystem root.)
save_dir = f'/data2T/jingchuan/tuned_models/ret/{model_name}_{timestr}'
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

In [9]:
# Sanity check: decode the token ids of the first training example back to text.
first_example_ids = train_dataset[0]['input_ids']
tokenizer.batch_decode(first_example_ids)

['Collectible Advertising</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 'Collectible Metalware</s><pad><pad><pad><pad><pad><pad><pad><pad>',
 'Other Medical & Lab Equipment</s><pad><pad><pad><pad><pad><pad><pad>',
 'First Aid Ointments, Creams & Oils</s>',
 "Women's Golf Belts</s><pad><pad><pad><pad><pad><pad><pad>",
 'Travel Flight Socks</s><pad><pad><pad><pad><pad><pad><pad><pad>',
 'Commercial Truck Exhaust Manifolds</s><pad><pad><pad><pad>',
 'US Dollar Coins Mixed Lots</s><pad><pad><pad><pad><pad><pad><pad>',
 'Audio/Video Media Repair Equipment</s><pad><pad><pad><pad><pad><pad><pad>',
 'Micropets</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 'Metaphysical Runes</s><pad><pad><pad><pad><pad><pad><pad><pad>',
 'Latch Hooking Kits</s><pad><pad><pad><pad><pad><pad><pad>']