In [None]:
!pip install transformers

In [None]:
import pandas as pd
import numpy as np
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
from typing import Dict
import torch
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score
from operator import itemgetter
from sklearn.metrics import precision_score
import pickle
import time
from tqdm import tqdm

from transformers import AutoModel
from transformers import AutoTokenizer

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
language_model_name = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(language_model_name)

In [None]:
# utility functions
def separate_predicates(predicates_list):
  index_predicate=[]
  for i,j in enumerate(predicates_list):
    if j is not '_':
      index_predicate.append(i) 
  return index_predicate  

def get_dict(d,k):
  return d.get(str(k))

def role_label(lis):
  return [label_to_id[i] for i in lis]

def role_pos(lis):
    return [pos_to_id[i] for i in lis]

def add_predicate(lis,predicates,k):
  predicate_sense=predicates[k]
  if predicate_sense in label_to_id.keys():
      lis[k]=label_to_id[predicate_sense]
  return lis

def id_to_onehotencode(k,length):
    temp=torch.zeros(length)
    temp[k]=1
    return temp
    


In [None]:
class TranDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.text= df[['lemmas']].values.tolist()
        self.labels=df[['roles']].values.tolist()  
        self.predicate_loc=df.index_predicate.values.tolist()
        self.pos_tag=df[['pos_tags']].values.tolist() 
        
    def __len__(self):
        return len(self.text)
  
    def __getitem__(self,idx):
        return self.text[idx],self.labels[idx],self.predicate_loc[idx],self.pos_tag[idx]
    


In [None]:
#preprocess inputs
# use_predicate : True for using predicate sense information while training
# use_predicate : False for using predicate location information while training
use_predicate=False
def collate_fn(batch):
    #get input_ids and attenstion mask for transformer network
    batch_out = tokenizer(
        [sentence[0][0] for sentence in batch],
        return_tensors="pt",
        padding=True,
        is_split_into_words=True,
    )
    # modify labels,predicate_location,pos_tags according to output from tokenizer
    labels = []
    predicate_loc = []
    pos_tag = []
    srl_tags = [sentence[1][0] for sentence in batch]
    predicate_tag=[sentence[2] for sentence in batch]
    pos =[sentence[3][0] for sentence in batch]
    for i, label in enumerate(zip(srl_tags,predicate_tag,pos)):
        #convert predictae position to onehot_encode 
        predicate_onehot=id_to_onehotencode(label[1],len(label[0]))
      # obtains the word_ids of the i-th sentence
        word_ids = batch_out.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        predicate_ids = []
        pos_ids = []
        for word_idx in word_ids:
        # Special tokens have a word id that is None. We set the label to -100 so they are automatically
        # ignored in the loss function.
            if word_idx is None:
                pos_ids.append(0)
                label_ids.append(-100)
                predicate_ids.append(0)
        # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                pos_ids.append(label[2][word_idx])
                label_ids.append(label[0][word_idx])
                #use predicate sense if true,else just use location information
                if use_predicate:
                    predicate_ids.append(predicate_onehot[word_idx]*label[0][word_idx])
                else:
                    predicate_ids.append(predicate_onehot[word_idx])
        # For the other tokens in a word, we set the label same to as above
            else:
                pos_ids.append(label[2][word_idx])
                label_ids.append(label[0][word_idx])
                #use predicate sense if true,else just use location information
                if use_predicate:
                    predicate_ids.append(predicate_onehot[word_idx]*label[0][word_idx])
                else:
                    predicate_ids.append(predicate_onehot[word_idx])
            previous_word_idx = word_idx
        labels.append(label_ids)
        predicate_loc.append(predicate_ids)
        pos_tag.append(pos_ids)
    
    # pad the labels with -100
    batch_max_length = len(max(labels, key=len))
    labels = [l + ([-100] * abs(batch_max_length - len(l))) for l in labels]
    pos_tag = [l + ([0] * abs(batch_max_length - len(l))) for l in pos_tag]
    batch_out["labels"] = torch.as_tensor(labels)
    batch_out["predicate_loc"] = torch.as_tensor(predicate_loc)[:,:,None]
    batch_out["pos_tag"] = F.one_hot(torch.as_tensor(pos_tag),18)
    
    return batch_out

In [None]:
data=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/EN/train.json').transpose(copy=True)
data["index_predicate"]=data['predicates'].apply(lambda x:separate_predicates(x))
data=data.explode('index_predicate')
data.dropna(inplace=True)
data['roles'] = data.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
data.dropna(inplace=True)
label_to_id = {n: i for i, n in enumerate(data.explode('roles').dropna()['roles'].unique())}
predicate_to_id={n: i+len(label_to_id)-1 for i, n in enumerate(data.explode('predicates').dropna()['predicates'].unique()) if '_' !=n}
label_to_id.update(predicate_to_id)
id_to_label = {i: n for n, i in label_to_id.items()}
pos_to_id = {n: i+1 for i, n in enumerate(data.explode('pos_tags')['pos_tags'].unique())}
data['roles']=data.roles.apply(lambda x: role_label(x))
data['roles'] = data.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
data['pos_tags']=data.pos_tags.apply(lambda x: role_pos(x))

In [105]:
pos_to_id

{'NOUN': 1,
 'ADV': 2,
 'VERB': 3,
 'SCONJ': 4,
 'DET': 5,
 'ADJ': 6,
 'ADP': 7,
 'PUNCT': 8,
 'PROPN': 9,
 'PART': 10,
 'NUM': 11,
 'CCONJ': 12,
 'AUX': 13,
 'PRON': 14,
 'SYM': 15,
 'X': 16,
 'INTJ': 17}

In [106]:
label_to_id

{'agent': 0,
 '_': 1,
 'theme': 2,
 'beneficiary': 3,
 'patient': 4,
 'topic': 5,
 'goal': 6,
 'recipient': 7,
 'co-theme': 8,
 'result': 9,
 'stimulus': 10,
 'experiencer': 11,
 'destination': 12,
 'value': 13,
 'attribute': 14,
 'location': 15,
 'source': 16,
 'cause': 17,
 'co-agent': 18,
 'time': 19,
 'co-patient': 20,
 'product': 21,
 'purpose': 22,
 'instrument': 23,
 'extent': 24,
 'asset': 25,
 'material': 26,
 'ASK_REQUEST': 27,
 'BENEFIT_EXPLOIT': 28,
 'PLAN_SCHEDULE': 29,
 'CARRY-OUT-ACTION': 30,
 'ESTABLISH': 31,
 'SIMPLIFY': 32,
 'PROPOSE': 33,
 'TAKE-INTO-ACCOUNT_CONSIDER': 34,
 'BEGIN': 35,
 'CIRCULATE_SPREAD_DISTRIBUTE': 36,
 'REFER': 37,
 'SHOW': 38,
 'PRECLUDE_FORBID_EXPEL': 39,
 'VIOLATE': 40,
 'VERIFY': 41,
 'CAUSE-SMT': 42,
 'ABSTAIN_AVOID_REFRAIN': 43,
 'TRANSMIT': 44,
 'SEE': 45,
 'SUMMON': 46,
 'GUARANTEE_ENSURE_PROMISE': 47,
 'RECEIVE': 48,
 'INCREASE_ENLARGE_MULTIPLY': 49,
 'DECREE_DECLARE': 50,
 'PAY': 51,
 'CAUSE-MENTAL-STATE': 52,
 'CAGE_IMPRISON': 53,
 'HU

In [None]:
train, val = train_test_split(data, test_size=0.2)

In [None]:
test_org=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/EN/dev.json').transpose(copy=True)
test_org["index_predicate"]=test_org['predicates'].apply(lambda x:separate_predicates(x))
test=test_org.explode('index_predicate')
test.dropna(inplace=True)
test['roles'] = test.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
test.dropna(inplace=True)
test['roles']=test.roles.apply(lambda x: role_label(x))
test['roles'] = test.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
test['pos_tags']=test.pos_tags.apply(lambda x: role_pos(x))

In [None]:
fr_data=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/FR/train.json').transpose(copy=True)
fr_data["index_predicate"]=fr_data['predicates'].apply(lambda x:separate_predicates(x))
fr_data=fr_data.explode('index_predicate')
fr_data.dropna(inplace=True)
fr_data['roles'] = fr_data.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
fr_data.dropna(inplace=True)
fr_data['roles']=fr_data.roles.apply(lambda x: role_label(x))
fr_data['roles'] = fr_data.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
fr_data['pos_tags']=fr_data.pos_tags.apply(lambda x: role_pos(x))

In [None]:
fr_train, fr_val = train_test_split(fr_data, test_size=0.2)

In [None]:
fr_test=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/FR/dev.json').transpose(copy=True)
fr_test["index_predicate"]=fr_test['predicates'].apply(lambda x:separate_predicates(x))
fr_test=fr_test.explode('index_predicate')
fr_test.dropna(inplace=True)
fr_test['roles'] = fr_test.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
fr_test.dropna(inplace=True)
fr_test['roles']=fr_test.roles.apply(lambda x: role_label(x))
fr_test['roles'] = fr_test.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
fr_test['pos_tags']=fr_test.pos_tags.apply(lambda x: role_pos(x))

In [None]:
es_data=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/ES/train.json').transpose(copy=True)
es_data["index_predicate"]=es_data['predicates'].apply(lambda x:separate_predicates(x))
es_data=es_data.explode('index_predicate')
es_data.dropna(inplace=True)
es_data['roles'] = es_data.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
es_data.dropna(inplace=True)
es_data['roles']=es_data.roles.apply(lambda x: role_label(x))
es_data['roles'] = es_data.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
es_data['pos_tags']=es_data.pos_tags.apply(lambda x: role_pos(x))

In [None]:
es_train, es_val = train_test_split(es_data, test_size=0.2)

In [None]:
es_test=pd.read_json('/kaggle/input/data-srl-nlp-modified/data_hw2/data_hw2/ES/dev.json').transpose(copy=True)
es_test["index_predicate"]=es_test['predicates'].apply(lambda x:separate_predicates(x))
es_test=es_test.explode('index_predicate')
es_test.dropna(inplace=True)
es_test['roles'] = es_test.apply(lambda x: get_dict(x['roles'], x['index_predicate']), axis=1)
es_test.dropna(inplace=True)
es_test['roles']=es_test.roles.apply(lambda x: role_label(x))
es_test['roles'] = es_test.apply(lambda x: add_predicate(x['roles'],x['predicates'], x['index_predicate']), axis=1)
es_test['pos_tags']=es_test.pos_tags.apply(lambda x: role_pos(x))

In [None]:
train_dataset=TranDataset(data)
val_dataset=TranDataset(val)
test_dataset=TranDataset(test)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True,collate_fn=collate_fn)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)
##########FRENCH#############
fr_train_dataset=TranDataset(fr_data)
fr_val_dataset=TranDataset(fr_val)
fr_test_dataset=TranDataset(fr_test)
fr_train_dataloader = torch.utils.data.DataLoader(fr_train_dataset, batch_size=64, shuffle=True,collate_fn=collate_fn)
fr_val_dataloader = torch.utils.data.DataLoader(fr_val_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)
fr_test_dataloader = torch.utils.data.DataLoader(fr_test_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)
##########ESP#############
es_train_dataset=TranDataset(es_data)
es_val_dataset=TranDataset(es_val)
es_test_dataset=TranDataset(es_test)
es_train_dataloader = torch.utils.data.DataLoader(es_train_dataset, batch_size=64, shuffle=True,collate_fn=collate_fn)
es_val_dataloader = torch.utils.data.DataLoader(es_val_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)
es_test_dataloader = torch.utils.data.DataLoader(es_test_dataset, batch_size=64, shuffle=False,collate_fn=collate_fn)


In [None]:
# testing the input which goes into model visually
kk=2
for i in train_dataloader:
    print(i)
    print(i.input_ids.shape)
    print(i.labels.shape)
    print(i.predicate_loc.shape)
    print(i.pos_tag.shape)
    break

In [None]:
class SRLModel(torch.nn.Module):
    def __init__(self,num_labels: int, fine_tune_lm: bool = True, *args, **kwargs) -> None:
        super().__init__()
        self.num_labels = num_labels
        # layers
        self.transformer_model = AutoModel.from_pretrained('xlm-roberta-base', output_hidden_states=True)
        if not fine_tune_lm:
            for param in self.transformer_model.parameters():
                param.requires_grad = False
        self.dropout = torch.nn.Dropout(0.5)
        self.lstm = torch.nn.LSTM(input_size=self.transformer_model.config.hidden_size+1+18, 
                            hidden_size=800, 
                            batch_first=True,
                            bidirectional=True,
                            num_layers=2,
                            dropout=0.5
                                 )
        self.lstm_output_dim = 800 * 2
        self.relu = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(self.lstm_output_dim, 512)
        self.classifier = torch.nn.Linear(512, num_labels)
        

    def forward(self, batch):
        input_ids = batch['input_ids'].to(device)
        attention_mask=batch['attention_mask'].to(device)
        predicate_loc = batch['predicate_loc'].to(device)
        pos_tags = batch['pos_tag'].to(device)
        transformers_outputs = self.transformer_model(input_ids,attention_mask)
        embed_out = torch.stack(transformers_outputs.hidden_states[-4:], dim=0).sum(dim=0)
        concat_predicate_loc=torch.cat((embed_out,pos_tags,predicate_loc), 2)
        output,_ = self.lstm(concat_predicate_loc)
        fc1 = self.fc1(self.dropout(output))
        fc1=self.relu(fc1)
        preds = self.classifier(self.dropout(fc1))
        return preds

In [None]:
# pre_trained_transformer_model = AutoModel.from_pretrained(language_model_name, output_hidden_states=True)
# model=SRLModel(len(label_to_id),fine_tune_lm=False).to(device)
# optimizer = torch.optim.Adam(model.parameters())
# criterion = torch.nn.CrossEntropyLoss(ignore_index = -100)

In [None]:
def train_model(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in tqdm(iterator):
        
        text = batch
        tags = batch['labels']
        
        optimizer.zero_grad()
        
        #text = [sent len, batch size]
        
        predictions = model(text)
        
        #predictions = [sent len, batch size, output dim]
        #tags = [sent len, batch size]
        
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1).type(torch.LongTensor).to(device)
        #predictions = [sent len * batch size, output dim]
        #tags = [sent len * batch size]
        
        loss = criterion(predictions, tags)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
        
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            text = batch
            tags = batch['labels']
            
            predictions = model(text)
            
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1).type(torch.LongTensor).to(device)
            

            
            loss = criterion(predictions, tags)
            
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

def predict(model, iterator):
    pred=[]
    tag=[]
    model.eval()
    
    with torch.no_grad():
        for batch in iterator:

            text = batch
            tags = batch['labels']
            
            predictions = model(text)
            
            predictions = predictions.view(-1, predictions.shape[-1])
            max_preds = predictions.argmax(dim = 1, keepdim = False)
            tags = tags.view(-1)

            pred.append(max_preds.tolist())
            tag.append(tags.tolist())
            
        
    return pred,tag

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [6]:
# # load pretrained model
load_pretrained_model=True
if load_pretrained_model==True:
    model_english = SRLModel(len(label_to_id),fine_tune_lm=False).to(device)
    model_english.load_state_dict(torch.load('/kaggle/working/final_eng_hw2_234.pt', map_location="cuda:0"))  # Choose whatever GPU device number you want
    optimizer = torch.optim.Adam(model_english.parameters())
    criterion = torch.nn.CrossEntropyLoss(ignore_index = -100)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
N_EPOCHS =10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss= train_model(model_english, train_dataloader, optimizer, criterion)
    valid_loss= evaluate(model_english, test_dataloader, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'modelhw2v1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')
    

In [69]:
#test on test dataset for either all arguments
pred,label=predict(model_english,fr_test_dataloader)
pred = [item for sublist in pred for item in sublist]
label = [item for sublist in label for item in sublist]
# modify to get a range
# get resulst only for predicate
all_labels = list(label_to_id.values())
exclude = [i for i in range(0,27)]
# eval_args = [label for label in all_labels if label not in exclude]
# get resulst only for arguments
eval_args=list(itemgetter('agent', 'theme', 'beneficiary', 'patient', 'topic', 'goal', 'recipient', 
                 'co-theme', 'result', 'stimulus', 'experiencer', 'destination', 'value', 
                 'attribute', 'location', 'source', 'cause', 'co-agent', 'time', 'co-patient',
                 'product', 'purpose', 'instrument', 'extent', 'asset', 'material')(label_to_id))
f1_score(label,pred,labels=eval_args,average='micro')

0.5875091366848046

In [None]:
torch.save(model_english.state_dict(), 'final_fr_hw2_234.pt')

# Fine Tune on French and Spanish

In [92]:
class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [103]:
class SRLModelEsp(torch.nn.Module):
    def __init__(self, base_model, num_labels: int, *args, **kwargs):
        super().__init__()
        self.base_model = base_model
        self.base_model.classifier=Identity()
#         freeze the base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.num_labels = num_labels
        self.classifier_es = torch.nn.Linear(512, num_labels)
        
    def forward(self, batch):
        preds = self.base_model(batch)
        preds = self.classifier_es(preds)
        return preds

SRL_model_es = SRLModelEsp(model_english, len(label_to_id)).to(device)
optimizer = torch.optim.Adam(SRL_model_es.parameters())
criterion = torch.nn.CrossEntropyLoss(ignore_index = -100)

In [None]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss= train_model(SRL_model_es, es_train_dataloader, optimizer, criterion)
    valid_loss= evaluate(SRL_model_es, es_val_dataloader, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(SRL_model_french.state_dict(), 'modelhw2v1_fr-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

In [109]:
#test on test dataset for either all arguments
pred,label=predict(SRL_model_es,es_test_dataloader)
pred = [item for sublist in pred for item in sublist]
label = [item for sublist in label for item in sublist]
# modify to get a range
# get resulst only for predicate
all_labels = list(label_to_id.values())
# exclude = [i for i in range(0,27)]
# eval_args = [label for label in all_labels if label not in exclude]
# get resulst only for arguments
eval_args=list(itemgetter('agent', 'theme', 'beneficiary', 'patient', 'topic', 'goal', 'recipient', 
                 'co-theme', 'result', 'stimulus', 'experiencer', 'destination', 'value', 
                 'attribute', 'location', 'source', 'cause', 'co-agent', 'time', 'co-patient',
                 'product', 'purpose', 'instrument', 'extent', 'asset', 'material')(label_to_id))
f1_score(label,pred,labels=eval_args,average='micro')

0.6287413660782809

In [110]:
torch.save(SRL_model_es.state_dict(), 'final_es_hw_234.pt')

In [8]:
class SRLModelFrench(torch.nn.Module):
    def __init__(self, base_model, num_labels: int, *args, **kwargs):
        super().__init__()
        self.base_model = base_model
        self.base_model.classifier=Identity()
#         freeze the base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.num_labels = num_labels
        self.classifier_fr = torch.nn.Linear(512, num_labels)
        
    def forward(self, batch):
        preds = self.base_model(batch)
        preds = self.classifier_fr(preds)
        return preds

SRL_model_french = SRLModelFrench(model_english, len(label_to_id)).to(device)
optimizer = torch.optim.Adam(SRL_model_french.parameters())
criterion = torch.nn.CrossEntropyLoss(ignore_index = -100)

In [9]:
N_EPOCHS = 100

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss= train_model(SRL_model_french, fr_train_dataloader, optimizer, criterion)
    valid_loss= evaluate(SRL_model_french, fr_val_dataloader, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(SRL_model_french.state_dict(), 'modelhw2v1_fr-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 01 | Epoch Time: 0m 8s
	Train Loss: 2.573
	 Val. Loss: 0.346


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 02 | Epoch Time: 0m 8s
	Train Loss: 0.336
	 Val. Loss: 0.262


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 03 | Epoch Time: 0m 8s
	Train Loss: 0.282
	 Val. Loss: 0.228


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 04 | Epoch Time: 0m 8s
	Train Loss: 0.257
	 Val. Loss: 0.209


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 05 | Epoch Time: 0m 8s
	Train Loss: 0.241
	 Val. Loss: 0.196


100%|██████████| 18/18 [00:06<00:00,  2.58it/s]


Epoch: 06 | Epoch Time: 0m 8s
	Train Loss: 0.231
	 Val. Loss: 0.187


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 07 | Epoch Time: 0m 8s
	Train Loss: 0.225
	 Val. Loss: 0.180


100%|██████████| 18/18 [00:07<00:00,  2.43it/s]


Epoch: 08 | Epoch Time: 0m 8s
	Train Loss: 0.215
	 Val. Loss: 0.175


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 09 | Epoch Time: 0m 8s
	Train Loss: 0.213
	 Val. Loss: 0.170


100%|██████████| 18/18 [00:06<00:00,  2.62it/s]


Epoch: 10 | Epoch Time: 0m 8s
	Train Loss: 0.207
	 Val. Loss: 0.166


100%|██████████| 18/18 [00:07<00:00,  2.44it/s]


Epoch: 11 | Epoch Time: 0m 8s
	Train Loss: 0.204
	 Val. Loss: 0.163


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 12 | Epoch Time: 0m 8s
	Train Loss: 0.199
	 Val. Loss: 0.160


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 13 | Epoch Time: 0m 8s
	Train Loss: 0.194
	 Val. Loss: 0.157


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 14 | Epoch Time: 0m 8s
	Train Loss: 0.193
	 Val. Loss: 0.154


100%|██████████| 18/18 [00:07<00:00,  2.44it/s]


Epoch: 15 | Epoch Time: 0m 8s
	Train Loss: 0.190
	 Val. Loss: 0.153


100%|██████████| 18/18 [00:07<00:00,  2.55it/s]


Epoch: 16 | Epoch Time: 0m 8s
	Train Loss: 0.188
	 Val. Loss: 0.151


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 17 | Epoch Time: 0m 8s
	Train Loss: 0.188
	 Val. Loss: 0.149


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 18 | Epoch Time: 0m 8s
	Train Loss: 0.184
	 Val. Loss: 0.147


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 19 | Epoch Time: 0m 8s
	Train Loss: 0.181
	 Val. Loss: 0.146


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 20 | Epoch Time: 0m 8s
	Train Loss: 0.180
	 Val. Loss: 0.145


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 21 | Epoch Time: 0m 8s
	Train Loss: 0.179
	 Val. Loss: 0.143


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 22 | Epoch Time: 0m 8s
	Train Loss: 0.177
	 Val. Loss: 0.142


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 23 | Epoch Time: 0m 8s
	Train Loss: 0.176
	 Val. Loss: 0.141


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 24 | Epoch Time: 0m 8s
	Train Loss: 0.176
	 Val. Loss: 0.140


100%|██████████| 18/18 [00:07<00:00,  2.48it/s]


Epoch: 25 | Epoch Time: 0m 8s
	Train Loss: 0.175
	 Val. Loss: 0.139


100%|██████████| 18/18 [00:07<00:00,  2.44it/s]


Epoch: 26 | Epoch Time: 0m 8s
	Train Loss: 0.174
	 Val. Loss: 0.138


100%|██████████| 18/18 [00:06<00:00,  2.58it/s]


Epoch: 27 | Epoch Time: 0m 8s
	Train Loss: 0.174
	 Val. Loss: 0.137


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 28 | Epoch Time: 0m 8s
	Train Loss: 0.172
	 Val. Loss: 0.136


100%|██████████| 18/18 [00:07<00:00,  2.45it/s]


Epoch: 29 | Epoch Time: 0m 8s
	Train Loss: 0.173
	 Val. Loss: 0.135


100%|██████████| 18/18 [00:06<00:00,  2.61it/s]


Epoch: 30 | Epoch Time: 0m 8s
	Train Loss: 0.172
	 Val. Loss: 0.135


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 31 | Epoch Time: 0m 8s
	Train Loss: 0.170
	 Val. Loss: 0.133


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 32 | Epoch Time: 0m 8s
	Train Loss: 0.168
	 Val. Loss: 0.133


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 33 | Epoch Time: 0m 8s
	Train Loss: 0.168
	 Val. Loss: 0.132


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 34 | Epoch Time: 0m 8s
	Train Loss: 0.168
	 Val. Loss: 0.132


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 35 | Epoch Time: 0m 8s
	Train Loss: 0.167
	 Val. Loss: 0.130


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 36 | Epoch Time: 0m 8s
	Train Loss: 0.167
	 Val. Loss: 0.130


100%|██████████| 18/18 [00:07<00:00,  2.46it/s]


Epoch: 37 | Epoch Time: 0m 8s
	Train Loss: 0.163
	 Val. Loss: 0.130


100%|██████████| 18/18 [00:07<00:00,  2.48it/s]


Epoch: 38 | Epoch Time: 0m 8s
	Train Loss: 0.169
	 Val. Loss: 0.129


100%|██████████| 18/18 [00:07<00:00,  2.55it/s]


Epoch: 39 | Epoch Time: 0m 8s
	Train Loss: 0.165
	 Val. Loss: 0.129


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 40 | Epoch Time: 0m 8s
	Train Loss: 0.167
	 Val. Loss: 0.128


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 41 | Epoch Time: 0m 8s
	Train Loss: 0.164
	 Val. Loss: 0.128


100%|██████████| 18/18 [00:06<00:00,  2.57it/s]


Epoch: 42 | Epoch Time: 0m 8s
	Train Loss: 0.165
	 Val. Loss: 0.127


100%|██████████| 18/18 [00:06<00:00,  2.59it/s]


Epoch: 43 | Epoch Time: 0m 8s
	Train Loss: 0.161
	 Val. Loss: 0.127


100%|██████████| 18/18 [00:07<00:00,  2.57it/s]


Epoch: 44 | Epoch Time: 0m 8s
	Train Loss: 0.162
	 Val. Loss: 0.126


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 45 | Epoch Time: 0m 8s
	Train Loss: 0.160
	 Val. Loss: 0.126


100%|██████████| 18/18 [00:07<00:00,  2.55it/s]


Epoch: 46 | Epoch Time: 0m 8s
	Train Loss: 0.159
	 Val. Loss: 0.126


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 47 | Epoch Time: 0m 8s
	Train Loss: 0.163
	 Val. Loss: 0.125


100%|██████████| 18/18 [00:06<00:00,  2.61it/s]


Epoch: 48 | Epoch Time: 0m 8s
	Train Loss: 0.160
	 Val. Loss: 0.125


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 49 | Epoch Time: 0m 8s
	Train Loss: 0.162
	 Val. Loss: 0.124


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 50 | Epoch Time: 0m 8s
	Train Loss: 0.160
	 Val. Loss: 0.124


100%|██████████| 18/18 [00:07<00:00,  2.46it/s]


Epoch: 51 | Epoch Time: 0m 8s
	Train Loss: 0.160
	 Val. Loss: 0.124


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 52 | Epoch Time: 0m 8s
	Train Loss: 0.161
	 Val. Loss: 0.123


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 53 | Epoch Time: 0m 8s
	Train Loss: 0.162
	 Val. Loss: 0.123


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 54 | Epoch Time: 0m 8s
	Train Loss: 0.159
	 Val. Loss: 0.122


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 55 | Epoch Time: 0m 8s
	Train Loss: 0.157
	 Val. Loss: 0.122


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 56 | Epoch Time: 0m 8s
	Train Loss: 0.159
	 Val. Loss: 0.121


100%|██████████| 18/18 [00:06<00:00,  2.59it/s]


Epoch: 57 | Epoch Time: 0m 8s
	Train Loss: 0.158
	 Val. Loss: 0.121


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 58 | Epoch Time: 0m 8s
	Train Loss: 0.157
	 Val. Loss: 0.121


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 59 | Epoch Time: 0m 8s
	Train Loss: 0.158
	 Val. Loss: 0.120


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 60 | Epoch Time: 0m 8s
	Train Loss: 0.159
	 Val. Loss: 0.120


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 61 | Epoch Time: 0m 8s
	Train Loss: 0.158
	 Val. Loss: 0.120


100%|██████████| 18/18 [00:06<00:00,  2.61it/s]


Epoch: 62 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.119


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 63 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.119


100%|██████████| 18/18 [00:07<00:00,  2.55it/s]


Epoch: 64 | Epoch Time: 0m 8s
	Train Loss: 0.159
	 Val. Loss: 0.119


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 65 | Epoch Time: 0m 8s
	Train Loss: 0.157
	 Val. Loss: 0.119


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 66 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.118


100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 67 | Epoch Time: 0m 8s
	Train Loss: 0.157
	 Val. Loss: 0.118


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 68 | Epoch Time: 0m 8s
	Train Loss: 0.154
	 Val. Loss: 0.119


100%|██████████| 18/18 [00:07<00:00,  2.45it/s]


Epoch: 69 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.118


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 70 | Epoch Time: 0m 8s
	Train Loss: 0.155
	 Val. Loss: 0.117


100%|██████████| 18/18 [00:06<00:00,  2.59it/s]


Epoch: 71 | Epoch Time: 0m 8s
	Train Loss: 0.155
	 Val. Loss: 0.118


100%|██████████| 18/18 [00:07<00:00,  2.52it/s]


Epoch: 72 | Epoch Time: 0m 8s
	Train Loss: 0.155
	 Val. Loss: 0.118


100%|██████████| 18/18 [00:06<00:00,  2.57it/s]


Epoch: 73 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.117


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 74 | Epoch Time: 0m 8s
	Train Loss: 0.156
	 Val. Loss: 0.117


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 75 | Epoch Time: 0m 8s
	Train Loss: 0.154
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:07<00:00,  2.48it/s]


Epoch: 76 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:06<00:00,  2.58it/s]


Epoch: 77 | Epoch Time: 0m 8s
	Train Loss: 0.155
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:06<00:00,  2.58it/s]


Epoch: 78 | Epoch Time: 0m 8s
	Train Loss: 0.154
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 79 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 80 | Epoch Time: 0m 8s
	Train Loss: 0.152
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 81 | Epoch Time: 0m 8s
	Train Loss: 0.154
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 82 | Epoch Time: 0m 8s
	Train Loss: 0.154
	 Val. Loss: 0.116


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 83 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 84 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 85 | Epoch Time: 0m 8s
	Train Loss: 0.152
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 86 | Epoch Time: 0m 8s
	Train Loss: 0.152
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 87 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 88 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.115


100%|██████████| 18/18 [00:07<00:00,  2.42it/s]


Epoch: 89 | Epoch Time: 0m 8s
	Train Loss: 0.151
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


Epoch: 90 | Epoch Time: 0m 8s
	Train Loss: 0.150
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 91 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 92 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 93 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.51it/s]


Epoch: 94 | Epoch Time: 0m 8s
	Train Loss: 0.152
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.55it/s]


Epoch: 95 | Epoch Time: 0m 8s
	Train Loss: 0.150
	 Val. Loss: 0.114


100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


Epoch: 96 | Epoch Time: 0m 8s
	Train Loss: 0.152
	 Val. Loss: 0.113


100%|██████████| 18/18 [00:07<00:00,  2.49it/s]


Epoch: 97 | Epoch Time: 0m 8s
	Train Loss: 0.151
	 Val. Loss: 0.113


100%|██████████| 18/18 [00:06<00:00,  2.59it/s]


Epoch: 98 | Epoch Time: 0m 8s
	Train Loss: 0.149
	 Val. Loss: 0.113


100%|██████████| 18/18 [00:07<00:00,  2.50it/s]


Epoch: 99 | Epoch Time: 0m 8s
	Train Loss: 0.153
	 Val. Loss: 0.113


100%|██████████| 18/18 [00:07<00:00,  2.56it/s]


Epoch: 100 | Epoch Time: 0m 8s
	Train Loss: 0.151
	 Val. Loss: 0.113


In [10]:
#test on test dataset for either all arguments
pred,label=predict(SRL_model_french,fr_test_dataloader)
pred = [item for sublist in pred for item in sublist]
label = [item for sublist in label for item in sublist]
# modify to get a range
# get resulst only for predicate
# all_labels = list(label_to_id.values())
# exclude = [i for i in range(0,27)]
# eval_args = [label for label in all_labels if label not in exclude]
# get resulst only for arguments
eval_args=list(itemgetter('agent', 'theme', 'beneficiary', 'patient', 'topic', 'goal', 'recipient', 
                 'co-theme', 'result', 'stimulus', 'experiencer', 'destination', 'value', 
                 'attribute', 'location', 'source', 'cause', 'co-agent', 'time', 'co-patient',
                 'product', 'purpose', 'instrument', 'extent', 'asset', 'material')(label_to_id))
f1_score(label,pred,labels=eval_args,average='micro')

0.6103439403036282

In [11]:
torch.save(SRL_model_french.state_dict(), 'final_fr_hw_234_final.pt')

In [None]:
#train- english /en-es/ en-fr
#EN-80-76-76
#ES-59-70-66
#FR-60-63-67.7