In [1]:
import numpy as np
import pandas as pd
from datasets import load_dataset
import itertools

from torch import nn
from nltk import tokenize as nltk_tokenizer

dataset = load_dataset("McGill-NLP/feedbackQA")

Found cached dataset feedback_qa (/home/raja/.cache/huggingface/datasets/McGill-NLP___feedback_qa/plain_text/1.0.0/20c8f938f417c88303bb7041cea9554c1d14667686d7d7c5dda83dd4f39e5dc4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
rating_scores = {'Excellent':3 , 'Acceptable':2 , 'Could be Improved':1, 'Bad': -1}

def process_df(df):
    df['question'] = df['question'].apply(lambda x: x.replace('\n',' '))
    df['answer'] = df['answer'].apply(lambda x: x.replace('\n',' '))
    df['list_feedback'] = df['feedback'].apply(lambda x: [ r + "___" + e for r,e in zip(x['rating'],x['explanation']) ])
    df['sampled_feedback'] = df['list_feedback'].apply(lambda x: np.random.choice(x).split("___") )
    df['rating_score'] = df['sampled_feedback'].apply(lambda x: rating_scores[x[0]])
    df['rating'] = df['sampled_feedback'].apply(lambda x: x[0])
    df['explanation'] = df['sampled_feedback'].apply(lambda x: x[1])
    return df

In [3]:
train_df = process_df(pd.DataFrame(dataset['train']))
val_df = process_df(pd.DataFrame(dataset['validation']))
test_df = process_df(pd.DataFrame(dataset['test']))

In [4]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

# Load model from HuggingFace Hub
bert_chkpt = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(bert_chkpt)
model = AutoModel.from_pretrained(bert_chkpt)

In [5]:
tokenizer.all_special_tokens

['<s>', '</s>', '[UNK]', '<pad>', '<mask>']

In [6]:
train_df.head()

Unnamed: 0,question,answer,feedback,list_feedback,sampled_feedback,rating_score,rating,explanation
0,How do I get help finding a job?,Coronavirus (COVID-19) information for job see...,"{'rating': ['Excellent', 'Could be Improved'],...",[Excellent___Has a link to detailed informatio...,"[Excellent, Has a link to detailed information...",3,Excellent,Has a link to detailed information about gover...
1,How do I get help finding a job?,Coronavirus (COVID-19) information for job see...,"{'rating': ['Excellent', 'Excellent'], 'explan...",[Excellent___A link to a job search website is...,"[Excellent, Includes a link to a Jobs Hub page...",3,Excellent,"Includes a link to a Jobs Hub page, which is b..."
2,How do I get help finding a job?,Coronavirus (COVID-19) information and support...,"{'rating': ['Bad', 'Acceptable'], 'explanation...",[Bad___Talks about tax credits for businesses ...,"[Acceptable, This answer discusses the Employm...",2,Acceptable,"This answer discusses the Employment Fund, whi..."
3,If I am in Australia on a worker holiday marke...,Frequently Asked Questions Working holiday mak...,"{'rating': ['Could be Improved', 'Acceptable']...",[Could be Improved___Answer is about Working H...,"[Could be Improved, Answer is about Working Ho...",1,Could be Improved,"Answer is about Working Holiday Makers, but do..."
4,If I am in Australia on a worker holiday marke...,Frequently Asked Questions COVID-19 Pandemic -...,"{'rating': ['Bad', 'Could be Improved'], 'expl...",[Bad___Discusses pandemic visas. Doesn't menti...,"[Bad, Discusses pandemic visas. Doesn't mentio...",-1,Bad,Discusses pandemic visas. Doesn't mention the ...


In [7]:
train_df['answer'].loc[0]

'Coronavirus (COVID-19) information for job seekers Exisiting job seekers If you are a current job seeker or participant, this fact sheet provides important information about mutual obligation requirements, appointments with your provider, and what to do if you are self-isolating:  Information for job seekers and participants  If you are participating in the ParentsNext program, this fact sheet provides important information about your activities and appointments.   Information for ParentsNext participants   ParentsNext participants Frequently Asked Questions   If you are a New Business Assistance with NEIS participant, these Frequently Asked Questions (FAQ) provides information about accessing the Coronavirus Supplement and what support is available during this time:  New Business Assistance with NEIS participants - Frequently Asked Questions  If you are a New Business Assistance with NEIS provider, these Frequently Asked Questions (FAQ) provides information about supporting NEIS part

In [8]:
tokenizer('Hello, how are you doing?'+ f" {tokenizer.eos_token} " + "Hemlooooo",add_special_tokens=True,return_tensors='pt', return_length=1)

{'input_ids': tensor([[    0,  7596,  1014,  2133,  2028,  2021,  2729,  1033,     2, 19614,
          4139,  9545,  9545,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'length': tensor([14])}

In [9]:
len(nltk_tokenizer.sent_tokenize(train_df['answer'].loc[0]))

3

In [10]:
tok_inp = tokenizer(nltk_tokenizer.sent_tokenize(train_df['answer'].loc[0]),add_special_tokens=False,return_token_type_ids=True)#,max_length=200,padding='max_length')
tok_inp

{'input_ids': [[21891, 23354, 1010, 2526, 17262, 1015, 2543, 1011, 2596, 2009, 3109, 24075, 4658, 17421, 3440, 3109, 24075, 2069, 2021, 2028, 1041, 2787, 3109, 29448, 2034, 13184, 1014, 2027, 2759, 7127, 3644, 2594, 2596, 2059, 8207, 14991, 5922, 1014, 14655, 2011, 2119, 10806, 1014, 2002, 2058, 2004, 2083, 2069, 2021, 2028, 2973, 1015, 11167, 22252, 1028, 2596, 2009, 3109, 24075, 2002, 6822, 2069, 2021, 2028, 8023, 2003, 2000, 3012, 2642, 18417, 2569, 1014, 2027, 2759, 7127, 3644, 2594, 2596, 2059, 2119, 3454, 2002, 14655, 1016], [2596, 2009, 3012, 2642, 18417, 6822, 3012, 2642, 18417, 6822, 4707, 2360, 3984, 2069, 2021, 2028, 1041, 2051, 2453, 5379, 2011, 11269, 2487, 13184, 1014, 2126, 4707, 2360, 3984, 1010, 6908, 4164, 1011, 3644, 2596, 2059, 3233, 2079, 2000, 21891, 23354, 12452, 2002, 2058, 2494, 2007, 2804, 2080, 2027, 2055, 1028, 2051, 2453, 5379, 2011, 11269, 2487, 6822, 1015, 4707, 2360, 3984, 2069, 2021, 2028, 1041, 2051, 2453, 5379, 2011, 11269, 2487, 10806, 1014, 2126, 47

In [11]:
tokenizer('Hello, how are you?',return_special_tokens_mask=True,add_special_tokens=True, padding='max_length', max_length=20)['special_tokens_mask']

[1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [12]:
import tqdm

class feedback_QA_dataset(Dataset):
    
    def __init__(self,df,max_length=512):
        self.df = df
        self.max_len = max_length
        self.data = []
        skipped = 0
        
        for i in tqdm.tqdm(range(len(self.df)),desc='vectorizing..'):
            
            d = {}
            
            tok_question = tokenizer('Question: ' + self.df.iloc[i]['question'] + ' Answer: ', add_special_tokens=False)
            tok_answer = tokenizer(self.df.iloc[i]['answer'], add_special_tokens=False)
            tok_feedback = tokenizer(self.df.iloc[i]['explanation'], add_special_tokens=False)
            
            if len(tok_question['input_ids']+tok_answer['input_ids']+tok_feedback['input_ids'])+4 > self.max_len:
                skipped +=1
                continue
            
            context = [tokenizer.bos_token_id] + tok_question['input_ids'] + tok_answer['input_ids']
            context_attn = [1] + tok_question['attention_mask'] + tok_answer['attention_mask']
            context_pool_mask = [0] + [0]*len(tok_question['input_ids']) + tok_answer['attention_mask']
            
            
            d['context_w_feedback'] = context + [tokenizer.sep_token_id]*2 + tok_feedback['input_ids'] + [tokenizer.eos_token_id]
            
            PAD_LEN = self.max_len - len(d['context_w_feedback'])
            
            d['context_w_feedback'] += [tokenizer.pad_token_id]*PAD_LEN
            d['context_w_feedback_attn'] = context_attn + [1,1] + tok_feedback['attention_mask'] + [1] + [0]*PAD_LEN            
            d['context'] = d['context_w_feedback']
            d['context_attn'] = context_attn + [1,0] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN
            
            d['feedback_pool_mask'] = [0]*len(context_pool_mask) + [0,0] + tok_feedback['attention_mask'] + [0] + [0]*PAD_LEN
            d['answer_pool_mask'] = context_pool_mask + [0,0] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN
            
            answer_phrases = nltk_tokenizer.sent_tokenize(self.df.iloc[i]['answer'])
            tok_phrases = tokenizer(answer_phrases,add_special_tokens=False,return_token_type_ids=True)
            
            d['answer_phrases_pool_mask'] = []
            
            for j in range(len(answer_phrases)):
                answer_phrases_attn_mask = tok_phrases['token_type_ids'].copy()
                answer_phrases_attn_mask[j] = tok_phrases['attention_mask'][j].copy()
                answer_phrases_attn_mask = list(itertools.chain.from_iterable(answer_phrases_attn_mask))
                # pad_len = len(tok_answer['attention_mask']) - len(answer_phrases_attn_mask)
                # answer_phrases_attn_mask += [0]*pad_len
                
                answer_phrase_pool_mask = [0] + [0]*len(tok_question['input_ids']) + answer_phrases_attn_mask + [0,0] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN
                
                d['answer_phrases_pool_mask'].append(answer_phrase_pool_mask)
            
            if len(d['answer_phrases_pool_mask'][0])>len(d['answer_pool_mask']):
                skipped +=1
                continue
                
            self.data.append(d)
        print('Skipped: ',skipped)

    def add_neg_samples(self):
        for i in tqdm.tqdm(range(self.__len__()),desc='adding neg samples...'):
            self.data[i]['feedback_set'] = [self.data[i]['context_w_feedback']]
            self.data[i]['feedback_attn_set'] = [self.data[i]['context_w_feedback_attn']]
            self.data[i]['feedback_pool_mask_set'] = [self.data[i]['feedback_pool_mask']]
            L = list(range(self.__len__()))
            L.remove(i)
            neg_samples_idx = np.random.choice(L,size=4)
            for n_id in neg_samples_idx:
                self.data[i]['feedback_set'].append(self.data[n_id]['context_w_feedback'])
                self.data[i]['feedback_attn_set'].append(self.data[n_id]['context_w_feedback_attn'])
                self.data[i]['feedback_pool_mask_set'].append(self.data[n_id]['feedback_pool_mask'])
            for k in self.data[i].keys():
                self.data[i][k] = torch.tensor(self.data[i][k])
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return self.data[idx]

In [13]:
# train_dataset = feedback_QA_dataset(train_df)
# train_dataset.add_neg_samples()
# valid_dataset = feedback_QA_dataset(val_df)
# valid_dataset.add_neg_samples()
test_dataset = feedback_QA_dataset(test_df)
test_dataset.add_neg_samples()

vectorizing..:   0%|                                                                                                                           | 0/1995 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1323 > 512). Running this sequence through the model will result in indexing errors
vectorizing..: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1995/1995 [00:06<00:00, 325.52it/s]


Skipped:  173


adding neg samples...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1822/1822 [00:08<00:00, 224.95it/s]


In [14]:
# train_DL = DataLoader(train_dataset,batch_size=1,shuffle=True)
# valid_DL = DataLoader(valid_dataset,batch_size=1,shuffle=True)
test_DL = DataLoader(test_dataset,batch_size=1,shuffle=False)

In [15]:
for b in test_DL:
    for k in b.keys():
        print(k,b[k].shape)
    break

context_w_feedback torch.Size([1, 512])
context_w_feedback_attn torch.Size([1, 512])
context torch.Size([1, 512])
context_attn torch.Size([1, 512])
feedback_pool_mask torch.Size([1, 512])
answer_pool_mask torch.Size([1, 512])
answer_phrases_pool_mask torch.Size([1, 2, 512])
feedback_set torch.Size([1, 5, 512])
feedback_attn_set torch.Size([1, 5, 512])
feedback_pool_mask_set torch.Size([1, 5, 512])


In [70]:
from transformers import BartForConditionalGeneration

device = 'cuda:0'

model = AutoModel.from_pretrained(bert_chkpt).to(device)
# model.load_state_dict(torch.load('Detect_Span_FB_MPNET_chkpts_1/best_model_chkpt.pth.tar')['model_state'])

In [65]:
chkpt = torch.load('Detect_Span_FB_MPNET_chkpts_3/Epoch_0_model_chkpt.pth.tar')
from collections import OrderedDict
model_state = OrderedDict([(k[6:],v) for k,v in chkpt['model_state'].items()])    

In [66]:
model.load_state_dict(model_state)

<All keys matched successfully>

In [69]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    se = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return F.normalize(se, p=2, dim=1)

j = 0

with torch.no_grad():
    model.eval()
    for b in test_DL:
        se = mean_pooling( model(input_ids = b['context_w_feedback'].to(device),attention_mask=b['context_attn'].to(device))[0], b['answer_pool_mask'].to(device))
        
        fmo = model(input_ids = b['feedback_set'][0].to(device),attention_mask=b['feedback_attn_set'][0].to(device))[0]
        # print(b['feedback_set'][0].shape, b['feedback_attn_set'][0].shape, fmo.shape, b['feedback_pool_mask_set'][0].shape)
        
        fe = mean_pooling(model(input_ids = b['feedback_set'][0].to(device),attention_mask=b['feedback_attn_set'][0].to(device))[0], b['feedback_pool_mask_set'][0].to(device))
        
        pmo = model(input_ids = b['context_w_feedback'].to(device),attention_mask=b['context_attn'].to(device))
        # print(pmo[0].shape,b['answer_phrases_pool_mask'].shape,pmo[0].repeat(b['answer_phrases_pool_mask'][0].shape[0],1,1).shape,b['answer_phrases_pool_mask'][0].shape)
        pe = mean_pooling(pmo[0].repeat(b['answer_phrases_pool_mask'][0].shape[0],1,1),b['answer_phrases_pool_mask'][0].to(device) )# for i in range(b['answer_phrases_pool_mask'][0].shape[0])]
        # pe = torch.stack(pe).squeeze(1)
        cos_sim = F.cosine_similarity(se,fe,dim=1)
        cos_phrase_sim = torch.matmul(pe,fe.transpose(1,0))
        # print(fe.shape,se.shape,pe.shape,cos_sim,cos_phrase_sim.mean(0))
        
        sent_probs = F.softmax(cos_sim,dim=-1)
        phrase_probs = F.softmax(cos_phrase_sim,dim=-1)
        
        print('\nInput: ',tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['context_attn'][0]),skip_special_tokens=True),'\n')
        print('Feedback: ',tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['feedback_pool_mask'][0]),skip_special_tokens=True),'\n')
        for i in range(b['answer_phrases_pool_mask'][0].shape[0]):
            relevance = phrase_probs[i][0] - sent_probs[0]
            
            phrase_tok = torch.mul(b['context_w_feedback'][0],b['answer_phrases_pool_mask'][0][i])
            print(f"Phrase {i}:",tokenizer.decode(phrase_tok,skip_special_tokens=True))
            print(f"Relevance of phrase {i} is {relevance}",'\n')
        
#         print('softmax: ',F.softmax(cos_sim),F.softmax(cos_phrase_sim,dim=-1))
        
#         tgt_tensor = torch.zeros(b['feedback_set'].shape[1] , device=device)
#         tgt_tensor[0] = 1.0
#         print('CE Loss: ', F.cross_entropy(cos_sim,target=tgt_tensor), F.cross_entropy(cos_phrase_sim.mean(0),target=torch.tensor([1.0,0,0,0,0]).to(device)))
        print('----------------------------')
        j+=1
        if j>5:
            break

del model


Input:  question : what are my options if i can not support myself on a whm visa? answer : frequently asked questions covid - 19 pandemic - australian government endorsed event ( agee ) stream of the temporary activity ( subclass 408 ) visa frequently asked questions when can i apply for the covid - 19 pandemic event visa? you should only apply for this visa is you are unable to depart australia, your temporary visa expires in less than 28 days ( or did not expire more than 28 days ago ) and you have no other visa options available to you. 

Feedback:  this is off task, the answer is more about an expired visa 

Phrase 0: frequently asked questions covid - 19 pandemic - australian government endorsed event ( agee ) stream of the temporary activity ( subclass 408 ) visa frequently asked questions when can i apply for the covid - 19 pandemic event visa?
Relevance of phrase 0 is -0.0016022920608520508 

Phrase 1: you should only apply for this visa is you are unable to depart australia, 

## Evaluating over test set

In [71]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    se = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return F.normalize(se, p=2, dim=1)

j = 0

with torch.no_grad():
    model.eval()
    with open('baseline_test_preds.txt','w') as f:
        for b in tqdm.tqdm(test_DL,desc='evaluating...'):
            se = mean_pooling( model(input_ids = b['context_w_feedback'].to(device),attention_mask=b['context_attn'].to(device))[0], b['answer_pool_mask'].to(device))

            fmo = model(input_ids = b['feedback_set'][0].to(device),attention_mask=b['feedback_attn_set'][0].to(device))[0]
            # print(b['feedback_set'][0].shape, b['feedback_attn_set'][0].shape, fmo.shape, b['feedback_pool_mask_set'][0].shape)

            fe = mean_pooling(model(input_ids = b['feedback_set'][0].to(device),attention_mask=b['feedback_attn_set'][0].to(device))[0], b['feedback_pool_mask_set'][0].to(device))

            pmo = model(input_ids = b['context_w_feedback'].to(device),attention_mask=b['context_attn'].to(device))
            # print(pmo[0].shape,b['answer_phrases_pool_mask'].shape,pmo[0].repeat(b['answer_phrases_pool_mask'][0].shape[0],1,1).shape,b['answer_phrases_pool_mask'][0].shape)
            pe = mean_pooling(pmo[0].repeat(b['answer_phrases_pool_mask'][0].shape[0],1,1),b['answer_phrases_pool_mask'][0].to(device) )# for i in range(b['answer_phrases_pool_mask'][0].shape[0])]
            # pe = torch.stack(pe).squeeze(1)
            cos_sim = F.cosine_similarity(se,fe,dim=1)
            cos_phrase_sim = torch.matmul(pe,fe.transpose(1,0))
            # print(fe.shape,se.shape,pe.shape,cos_sim,cos_phrase_sim.mean(0))

            sent_probs = F.softmax(cos_sim,dim=-1)
            phrase_probs = F.softmax(cos_phrase_sim,dim=-1)

            f.write(f"\nInput: {tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['context_attn'][0]),skip_special_tokens=True)}\n\n")
            f.write(f"Feedback: {tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['feedback_pool_mask'][0]),skip_special_tokens=True)}\n\n")
            for i in range(b['answer_phrases_pool_mask'][0].shape[0]):
                relevance = phrase_probs[i][0] - sent_probs[0]

                phrase_tok = torch.mul(b['context_w_feedback'][0],b['answer_phrases_pool_mask'][0][i])
                f.write(f"Phrase {i}: {tokenizer.decode(phrase_tok,skip_special_tokens=True)}\n")
                f.write(f"Relevance of phrase {i} is {relevance}\n\n")

    #         print('softmax: ',F.softmax(cos_sim),F.softmax(cos_phrase_sim,dim=-1))

    #         tgt_tensor = torch.zeros(b['feedback_set'].shape[1] , device=device)
    #         tgt_tensor[0] = 1.0
    #         print('CE Loss: ', F.cross_entropy(cos_sim,target=tgt_tensor), F.cross_entropy(cos_phrase_sim.mean(0),target=torch.tensor([1.0,0,0,0,0]).to(device)))
            f.write('----------------------------')
            # j+=1
            # if j>5:
            #     break

del model

evaluating...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1822/1822 [03:43<00:00,  8.16it/s]


In [230]:
t = torch.tensor([[[1,2,3,4,5],[6,7,8,9,0]]])
t.repeat(2,1,1)

tensor([[[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 0]],

        [[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 0]]])

In [82]:
class discriminator(nn.Module):
    def __init__(self, model_chkpt, device='cuda:0'):
        super().__init__()
        
        self.model = AutoModel.from_pretrained(model_chkpt).to(device)
        self.device = device
        
    def mean_pooling(self,model_output,attention_mask):
        token_embeddings = model_output #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        se = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return F.normalize(se, p=2, dim=1)
        
    def forward(self, b):
        sent_model_out = self.model(input_ids = b['context_w_feedback'].to(self.device),attention_mask=b['context_attn'].to(self.device))[0]
        feedback_model_out = self.model(input_ids = b['feedback_set'][0].to(self.device),attention_mask=b['feedback_attn_set'][0].to(self.device))[0]
        
        sent_emb = self.mean_pooling( sent_model_out, b['answer_pool_mask'].to(self.device))
        feedback_emb = self.mean_pooling( feedback_model_out, b['feedback_pool_mask_set'][0].to(self.device))
        
        # print(pmo[0].shape,b['answer_phrases_pool_mask'].shape)
        phrase_emb = self.mean_pooling( sent_model_out.repeat(b['answer_phrases_pool_mask'][0].shape[0],1,1), b['answer_phrases_pool_mask'][0].to(self.device) )
        # phrase_emb = torch.stack(phrase_emb).squeeze(1)
        cos_sim = F.cosine_similarity(sent_emb,feedback_emb,dim=1)
        cos_phrase_sim = torch.matmul(phrase_emb,feedback_emb.transpose(1,0))
        
        tgt_tensor = torch.zeros(b['feedback_set'].shape[1] , device=self.device)
        tgt_tensor[0] = 1.0 #the relevant feedback is always present at index 0
        
        return_dict = {'sent_ce_loss': F.cross_entropy(cos_sim,target=tgt_tensor),
                       'avg_phrase_ce_loss': F.cross_entropy(cos_phrase_sim.mean(0),target=tgt_tensor),
                       'sent_probs': F.softmax(cos_sim,dim=-1),
                       'phrase_probs': F.softmax(cos_phrase_sim,dim=-1)}
        
        return return_dict
        
        

In [83]:
def train(discriminator,train_dl,valid_dl,epochs,batch_size,optimizer,PATIENCE=20,save_dir=None):

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    
    discriminator.train()
    
    loss_acc = 0
    num_batches = 0
    total_steps = 0
    best_valid_loss = np.inf
    patience = PATIENCE
    
    train_loss_arr,valid_loss_arr = [],[]
    
    optimizer.zero_grad()
    discriminator.zero_grad()
    
    for E in range(epochs):
        
        valid_loss = validate(discriminator,valid_dl)
        valid_loss_arr.append(valid_loss/len(valid_dl))
        
        num_samples = 0
        
        for b in train_dl:
            
            y = discriminator(b)
                          # decoder_input_ids=b['feedback'].squeeze(1)[:,:-1].to(device),
                          # decoder_attention_mask=b['feedback_attn'].squeeze(1)[:,:-1].to(device))
            loss = y['sent_ce_loss'] + y['avg_phrase_ce_loss'] #F.cross_entropy(y.logits.permute(0,2,1), b['feedback'].squeeze(1)[:,1:].to(device), ignore_index=tokenizer.pad_token_id)
            
            num_samples+=1
            
            loss.backward()
            loss_acc += loss.item()
            
            if num_samples%batch_size==0:
                optimizer.step()

                num_batches += 1
                total_steps += 1
            
                train_loss_arr.append(loss_acc/num_batches)
                
                optimizer.zero_grad()
            
                if total_steps%100==0 and total_steps!=0:
                    print("Epoch:",E,"\t","Steps taken:",total_steps,"\tLoss:",loss_acc/num_batches)
            
        #print("Epoch:",E,"\t","Steps taken:",total_steps,"\tLoss:",loss_acc/num_batches)
        
        torch.save({'model_state':discriminator.state_dict(),
                    'optimizer':optimizer.state_dict(),
                    'epoch':E},
                    f"{save_dir}/Epoch_{E}_model_chkpt.pth.tar")
        
        if valid_loss<best_valid_loss:
            best_valid_loss = valid_loss
            patience = PATIENCE
            
            torch.save({'model_state':discriminator.state_dict(),
                        'optimizer':optimizer.state_dict(),
                        'epoch':E},
                        f"{save_dir}/best_model_chkpt.pth.tar")
        else:
            patience -= 1
            print(f"REDUCING PATIENCE...{patience}")

        if patience<=0:
            print("RUNNING OUT OF PATIENCE... TERMINATING")
            break
    
    
    return train_loss_arr,valid_loss_arr
                

In [84]:
def validate(discriminator,valid_dl):
    
    discriminator.eval()
    valid_loss = 0
    with torch.no_grad():
        for b in valid_dl:
            y = discriminator(b)
                          # decoder_input_ids=b['feedback'].squeeze(1)[:,:-1].to(device),
                          # decoder_attention_mask=b['feedback_attn'].squeeze(1)[:,:-1].to(device))
            loss = y['sent_ce_loss'] + y['avg_phrase_ce_loss'] #F.cross_entropy(y.logits.permute(0,2,1), b['feedback'].squeeze(1)[:,1:].to(device), ignore_index=tokenizer.pad_token_id)
            valid_loss += loss.item()
            
    print("Validation Loss:",valid_loss)
    return valid_loss

In [None]:
import os

from transformers import AutoModel

EPOCHS = 50
BATCH_SIZE = 16

device = 'cuda:1'

# MPNet = AutoModel.from_pretrained(bert_chkpt).to(device)
discriminator_model = discriminator(bert_chkpt,device=device)

optimizer = torch.optim.AdamW(discriminator_model.parameters(),lr=1e-5)

save_dir = 'Detect_Span_FB_MPNET_chkpts_3'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

train_loss,valid_loss = train(discriminator_model,
                              train_DL,
                              valid_DL,
                              EPOCHS,
                              BATCH_SIZE,
                              optimizer,
                              PATIENCE=5,
                              save_dir=save_dir)

Validation Loss: 3223.12977540493
Epoch: 0 	 Steps taken: 100 	Loss: 29.962738287448882
Epoch: 0 	 Steps taken: 200 	Loss: 29.572392408251762
Epoch: 0 	 Steps taken: 300 	Loss: 29.432847184737522
Validation Loss: 2379.9055646657944
Epoch: 1 	 Steps taken: 400 	Loss: 29.339444583654405


In [None]:
import json

with open('train_loss.json','w') as f:
    json.dump(train_loss,f)

with open('valid_loss.json','w') as f:
    json.dump(valid_loss,f)

In [None]:
train_loss_ds = np.array(train_loss)[np.round(np.linspace(0, len(train_loss) - 1, len(valid_loss))).astype(int)]
loss_df = pd.DataFrame({'train_loss':train_loss_ds , 'valid_loss':valid_loss})

In [None]:
from plotly import express as px
px.line(loss_df,y=['train_loss','valid_loss'])

In [None]:
discriminator.load_state_dict(torch.load('GenFB_BART_chkpts_1/Epoch_0_model_chkpt.pth.tar')['model_state'])

In [None]:
i = 0
for b in train_DL:
    out = discriminator.generate(inputs=b['input'][0:1,0].to(device),top_p=0.5)
    print(tokenizer.decode(b['input'][0:1,0][0],skip_special_tokens=True))
    print(tokenizer.decode(b['feedback'][0:1,0][0],skip_special_tokens=True))
    print(tokenizer.decode(out[0]))
    print("--------------------------------------------------------")
    i+=1
    if i>10:
        break