In [None]:
import numpy as np
import pandas as pd
from datasets import load_dataset
import itertools

from torch import nn
from nltk import tokenize as nltk_tokenizer

# Load the FeedbackQA dataset from the HuggingFace Hub; the splits used below
# are 'train', 'validation' and 'test'.
dataset = load_dataset("McGill-NLP/feedbackQA")

In [None]:
# Numeric score assigned to each textual rating label.
rating_scores = {'Excellent':3 , 'Acceptable':2 , 'Could be Improved':1, 'Bad': -1}

def process_df(df):
    """Flatten one dataset split and sample a single feedback per row.

    The frame is modified in place and also returned.

    Args:
        df: DataFrame with string columns ``question`` and ``answer`` and a
            ``feedback`` column holding dicts with parallel ``rating`` and
            ``explanation`` lists.

    Returns:
        The same DataFrame with newlines removed from ``question``/``answer``
        and the extra columns ``list_feedback``, ``sampled_feedback``,
        ``rating_score``, ``rating`` and ``explanation``.

    Raises:
        KeyError: if a sampled rating is not a key of ``rating_scores``.
        ValueError: if a row has no feedback to sample from.
    """
    df['question'] = df['question'].apply(lambda x: x.replace('\n',' '))
    df['answer'] = df['answer'].apply(lambda x: x.replace('\n',' '))
    # Each feedback is encoded as "<rating>___<explanation>" so that a single
    # random draw keeps the rating and its explanation together.
    df['list_feedback'] = df['feedback'].apply(lambda x: [ r + "___" + e for r,e in zip(x['rating'],x['explanation']) ])
    # maxsplit=1: only the first separator divides rating from explanation, so
    # an explanation that itself contains "___" is kept whole.
    df['sampled_feedback'] = df['list_feedback'].apply(lambda x: np.random.choice(x).split("___", 1) )
    df['rating_score'] = df['sampled_feedback'].apply(lambda x: rating_scores[x[0]])
    df['rating'] = df['sampled_feedback'].apply(lambda x: x[0])
    df['explanation'] = df['sampled_feedback'].apply(lambda x: x[1])
    return df

In [None]:
# Convert every split to a DataFrame and run it through process_df, in the
# order train -> validation -> test (the order matters for the random draws).
train_df, val_df, test_df = (
    process_df(pd.DataFrame(dataset[split_name]))
    for split_name in ('train', 'validation', 'test')
)

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

# Load the tokenizer from the HuggingFace Hub. Only the tokenizer is created
# here; the model itself is loaded in later cells.
# NOTE: despite its name, bert_chkpt points at a Llama-2 chat checkpoint.
bert_chkpt = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(bert_chkpt)
# model = AutoModel.from_pretrained(bert_chkpt)

In [None]:
# Display the special tokens known to the tokenizer.
tokenizer.all_special_tokens

In [None]:
# Preview the first rows of the processed training split.
train_df.head()

In [None]:
# Show the answer text of the training row labelled 0.
train_df['answer'].loc[0]

In [None]:
# Scratch check: tokenize two text pieces joined by the literal EOS token
# string, with special tokens added, returning PyTorch tensors and lengths.
tokenizer('Hello, how are you doing?'+ f" {tokenizer.eos_token} " + "Hemlooooo",add_special_tokens=True,return_tensors='pt', return_length=1)

In [None]:
# Number of sentences NLTK finds in the answer of the training row labelled 0.
len(nltk_tokenizer.sent_tokenize(train_df['answer'].loc[0]))

In [None]:
# Scratch check: tokenize each sentence of that answer separately, without
# special tokens, and also request token_type_ids.
tok_inp = tokenizer(nltk_tokenizer.sent_tokenize(train_df['answer'].loc[0]),add_special_tokens=False,return_token_type_ids=True)#,max_length=200,padding='max_length')
tok_inp

In [None]:
# Use token id 0 for padding. This mutates the shared tokenizer, and the
# dataset class below relies on tokenizer.pad_token_id being set.
# NOTE(review): presumably the checkpoint ships without a pad token; confirm
# that id 0 is an acceptable padding id for this vocabulary.
tokenizer.pad_token_id = 0
# Scratch check: pad a short sentence to 20 tokens and show the special-token mask.
tokenizer('Hello, how are you?',return_special_tokens_mask=True,add_special_tokens=True, padding='max_length', max_length=20)

In [None]:
import tqdm

class feedback_QA_dataset(Dataset):
    """Tokenized (question, answer, feedback) samples with pooling masks.

    Every kept sample is laid out as a fixed-length sequence

        [BOS] question answer [EOS] feedback [EOS] [PAD]*

    together with attention masks and 0/1 pooling masks that select the whole
    answer, each individual answer sentence, or the feedback tokens.

    Relies on the module-level ``tokenizer`` (with ``pad_token_id`` set) and on
    ``nltk_tokenizer`` for sentence splitting.

    Args:
        df: DataFrame with string columns ``question``, ``answer`` and
            ``explanation``.
        max_length: length every sequence is padded to. Samples whose question,
            answer and feedback token counts plus 4 exceed it are skipped.
    """

    def __init__(self,df,max_length=512):
        self.df = df
        self.max_len = max_length
        self.data = []
        skipped = 0

        for i in tqdm.tqdm(range(len(self.df)),desc='vectorizing..'):
            row = self.df.iloc[i]
            answer_text = row['answer'].strip().replace('  ',' ')

            d = {'id':i}

            tok_question = tokenizer('Question: ' + row['question'] + ' Answer: ', add_special_tokens=False)
            tok_answer = tokenizer(answer_text, add_special_tokens=False)
            tok_feedback = tokenizer(row['explanation'], add_special_tokens=False)

            d['question'] = tok_question['input_ids']
            d['answer'] = tok_answer['input_ids']
            d['feedback'] = tok_feedback['input_ids']

            # Leave room for the special tokens added below (BOS and two EOS).
            if len(tok_question['input_ids']+tok_answer['input_ids']+tok_feedback['input_ids'])+4 > self.max_len:
                skipped +=1
                continue

            context = [tokenizer.bos_token_id] + tok_question['input_ids'] + tok_answer['input_ids']
            context_attn = [1] + tok_question['attention_mask'] + tok_answer['attention_mask']
            # Pooling mask that is 1 only on the answer tokens.
            context_pool_mask = [0] + [0]*len(tok_question['input_ids']) + tok_answer['attention_mask']

            d['context_w_feedback'] = context + [tokenizer.eos_token_id] + tok_feedback['input_ids'] + [tokenizer.eos_token_id]

            PAD_LEN = self.max_len - len(d['context_w_feedback'])

            d['Input_len'] = len(d['context_w_feedback'])
            d['PAD_LEN'] = PAD_LEN

            d['context_w_feedback'] += [tokenizer.pad_token_id]*PAD_LEN
            d['context_w_feedback_attn'] = context_attn + [1] + tok_feedback['attention_mask'] + [1] + [0]*PAD_LEN
            # 'context' shares the token ids of 'context_w_feedback'; the
            # feedback part is hidden through 'context_attn' instead.
            d['context'] = d['context_w_feedback']
            d['context_attn'] = context_attn + [1] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN

            d['feedback_pool_mask'] = [0]*len(context_pool_mask) + [0] + tok_feedback['attention_mask'] + [0] + [0]*PAD_LEN
            d['answer_pool_mask'] = context_pool_mask + [0] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN

            answer_phrases = nltk_tokenizer.sent_tokenize(answer_text)
            if not answer_phrases:
                # An answer without any sentence has no phrase masks to build.
                skipped +=1
                continue
            tok_phrases = tokenizer(answer_phrases,add_special_tokens=False,return_token_type_ids=True)

            d['tok_phrases'] = tok_phrases['input_ids']
            d['answer_phrases_pool_mask'] = []

            for j in range(len(answer_phrases)):
                # token_type_ids serve as the per-sentence "off" template and
                # only sentence j gets its attention mask switched on.
                # NOTE(review): assumes token_type_ids are all zeros here - confirm.
                answer_phrases_attn_mask = tok_phrases['token_type_ids'].copy()
                answer_phrases_attn_mask[j] = tok_phrases['attention_mask'][j].copy()
                answer_phrases_attn_mask = list(itertools.chain.from_iterable(answer_phrases_attn_mask))
                # Sentences are tokenized separately, so their total length may
                # differ from the length of the answer tokenized as a whole.
                pad_len = len(tok_answer['attention_mask']) - len(answer_phrases_attn_mask)
                answer_phrases_attn_mask += [0]*pad_len

                answer_phrase_pool_mask = [0] + [0]*len(tok_question['input_ids']) + answer_phrases_attn_mask + [0] + [0]*len(tok_feedback['attention_mask']) + [0] + [0]*PAD_LEN

                d['answer_phrases_pool_mask'].append(answer_phrase_pool_mask)

            # A longer mask means the per-sentence tokens overflow the answer
            # span, so the sample cannot be aligned and is dropped.
            if len(d['answer_phrases_pool_mask'][0])>len(d['answer_pool_mask']):
                skipped +=1
                continue

            else:
                self.data.append(d)

        print('Skipped: ',skipped)

    def add_neg_samples(self, num_neg=4):
        """Attach negative feedback candidates and convert samples to tensors.

        For every sample, index 0 of ``feedback_set``, ``feedback_attn_set``
        and ``feedback_pool_mask_set`` holds the sample's own sequence; the
        following ``num_neg`` entries come from other samples drawn uniformly
        with replacement. Afterwards every field except ``tok_phrases`` is
        converted to a ``torch.Tensor``. Intended to be called once.

        Args:
            num_neg: number of negative candidates per sample.

        Raises:
            ValueError: if the dataset holds exactly one sample, so no other
                sample is available as a negative.
        """
        n_items = self.__len__()
        if n_items == 1:
            raise ValueError('add_neg_samples needs at least two samples to draw negatives from')

        # Pass 1: build all candidate sets while every field is still a plain
        # list, so no set ever mixes tensors with lists.
        for i in tqdm.tqdm(range(n_items),desc='adding neg samples...'):
            item = self.data[i]
            # Draw from the n_items-1 other samples and shift indices >= i by
            # one, which skips i without materialising a candidate list.
            draws = np.random.choice(n_items - 1, size=num_neg)
            neg_samples_idx = [int(k) + 1 if k >= i else int(k) for k in draws]
            members = [item] + [self.data[n_id] for n_id in neg_samples_idx]
            item['feedback_set'] = [m['context_w_feedback'] for m in members]
            item['feedback_attn_set'] = [m['context_w_feedback_attn'] for m in members]
            item['feedback_pool_mask_set'] = [m['feedback_pool_mask'] for m in members]

        # Pass 2: tensor conversion. 'tok_phrases' stays a nested list because
        # its rows have different lengths.
        for item in self.data:
            for k in list(item.keys()):
                if k!='tok_phrases':
                    item[k] = torch.tensor(item[k])

    def __len__(self):
        """Number of samples kept after filtering."""
        return len(self.data)

    def __getitem__(self,idx):
        """Return the sample dict stored at position ``idx``."""
        return self.data[idx]

In [None]:
# Vectorize each split and attach negative feedback candidates to it.
# add_neg_samples draws from NumPy's global random state, so the order of
# these statements determines which negatives each split receives.
train_dataset = feedback_QA_dataset(train_df)
train_dataset.add_neg_samples()
valid_dataset = feedback_QA_dataset(val_df)
valid_dataset.add_neg_samples()
test_dataset = feedback_QA_dataset(test_df)
test_dataset.add_neg_samples()

In [None]:
# One sample per batch: the per-sample fields (question, answer, feedback,
# sentence masks) differ in length, and the model code below indexes [0] into
# the batch dimension. The training loop accumulates gradients instead.
train_DL = DataLoader(train_dataset,batch_size=1,shuffle=True)
valid_DL = DataLoader(valid_dataset,batch_size=1,shuffle=True)
test_DL = DataLoader(test_dataset,batch_size=1,shuffle=False)

In [None]:
# Print the name and shape of every field of the first training batch.
for b in train_DL:
    for k in b.keys():
        print(k)
        # 'tok_phrases' is kept as nested lists rather than a tensor, so it
        # has no .shape; report its length instead.
        print(b[k].shape if torch.is_tensor(b[k]) else f'list of length {len(b[k])}')
    break

In [None]:
# Scratch check: repeat(2,1,1) tiles the tensor twice along the first dimension.
t = torch.tensor([[[1,2,3],[4,5,6]]])
t.shape,t.repeat(2,1,1).shape

In [None]:
from transformers import BartForConditionalGeneration

# Device used by the exploratory run below (hard-coded to CUDA device index 1).
device = 'cuda:1'

# Load the pretrained model, using a fixed local cache directory, and move it
# to the chosen device.
model = AutoModel.from_pretrained(bert_chkpt,cache_dir='/home/jupyter/Ravi_new/HF_cache').to(device)

def mean_pooling(model_output, attention_mask):
    """Masked mean over the sequence dimension, L2-normalised.

    Args:
        model_output: token embeddings of shape (batch, seq, hidden).
        attention_mask: 0/1 mask of shape (batch, seq); only positions with a
            1 contribute to the mean.

    Returns:
        Tensor of shape (batch, hidden) with unit L2 norm per row (all zeros
        for a row whose mask is empty).
    """
    # Broadcast the mask over the hidden dimension.
    weights = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    masked_total = (model_output * weights).sum(dim=1)
    # Clamp so an all-zero mask does not divide by zero.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return F.normalize(masked_total / token_count, p=2, dim=1)

# Exploratory pass with the untrained model: for a few test samples, compare
# the pooled answer / answer-sentence embeddings against the candidate
# feedback embeddings and print the resulting probabilities and losses.
j = 0

with torch.no_grad():
    for b in test_DL:
        # Encode the context once; its hidden states are reused for both the
        # whole-answer and the per-sentence pooling.
        pmo = model(input_ids = b['context_w_feedback'].to(device),attention_mask=b['context_attn'].to(device))
        se = mean_pooling(pmo[0], b['answer_pool_mask'].to(device))

        # Encode the candidate feedbacks once (index 0 is the true feedback).
        fmo = model(input_ids = b['feedback_set'][0].to(device),attention_mask=b['feedback_attn_set'][0].to(device))[0]
        fe = mean_pooling(fmo, b['feedback_pool_mask_set'][0].to(device))

        # One pooled embedding per answer sentence: tile the context hidden
        # states so each sentence mask gets its own copy.
        phrase_masks = b['answer_phrases_pool_mask'][0].to(device)
        pe = mean_pooling(pmo[0].repeat(phrase_masks.shape[0],1,1), phrase_masks)

        cos_sim = F.cosine_similarity(se,fe,dim=1)
        # Embeddings are unit-normalised, so the dot product is the cosine.
        cos_phrase_sim = torch.matmul(pe,fe.transpose(1,0))
        print(fe.shape,se.shape,pe.shape,cos_sim,cos_phrase_sim.mean(0))

        sent_probs = F.softmax(cos_sim,dim=-1)
        phrase_probs = F.softmax(cos_phrase_sim,dim=-1)

        print('\nInput: ',tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['context_attn'][0]),skip_special_tokens=True),'\n')
        print('Feedback: ',tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['feedback_pool_mask'][0]),skip_special_tokens=True),'\n')
        for i in range(b['answer_phrases_pool_mask'][0].shape[0]):
            # How much more (or less) probability sentence i puts on the true
            # feedback than the whole answer does.
            relevance = phrase_probs[i][0] - sent_probs[0]

            phrase_tok = torch.mul(b['context_w_feedback'][0],b['answer_phrases_pool_mask'][0][i])
            print(f"Phrase {i}:",tokenizer.decode(phrase_tok,skip_special_tokens=True))
            print(f"Relevance of phrase {i} is {relevance}",'\n')

        print('softmax: ',sent_probs,phrase_probs)

        # Target distribution: all mass on the true feedback at index 0,
        # sized from the actual number of candidates.
        tgt_tensor = torch.zeros(b['feedback_set'].shape[1] , device=device)
        tgt_tensor[0] = 1.0
        print('CE Loss: ', F.cross_entropy(cos_sim,target=tgt_tensor), F.cross_entropy(cos_phrase_sim.mean(0),target=tgt_tensor))
        print('----------------------------')
        j+=1
        if j>15:
            break

# Drop the reference so the model can be freed before the trainable copy is built.
del model

In [None]:
# Scratch check: show the values produced by tiling along the first dimension.
t = torch.tensor([[[1,2,3,4,5],[6,7,8,9,0]]])
t.repeat(2,1,1)

In [None]:
class discriminator(nn.Module):
    """Scores candidate feedbacks against an answer and its sentences.

    The wrapped encoder embeds the context and every candidate feedback;
    masked mean pooling yields one embedding for the whole answer, one per
    answer sentence and one per candidate. Cosine similarities between them
    are used as logits.

    Args:
        model_chkpt: checkpoint name or path passed to
            ``AutoModel.from_pretrained``.
        device: device the batch tensors are moved to in ``forward``.
        cache_dir: cache directory passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_chkpt, device='cpu', cache_dir='/home/jupyter/Ravi_new/HF_cache'):
        super().__init__()

        # Load the checkpoint the caller asked for rather than a module-level global.
        self.model = AutoModel.from_pretrained(model_chkpt,cache_dir=cache_dir)
        self.device = device

    def mean_pooling(self,model_output,attention_mask):
        """Masked mean of ``model_output`` over dim 1, L2-normalised per row."""
        token_embeddings = model_output
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp so an all-zero mask does not divide by zero.
        se = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return F.normalize(se, p=2, dim=1)

    def forward(self, b):
        """Compute losses and probabilities for one batch of size 1.

        Args:
            b: batch dict with the tensor fields ``context_w_feedback``,
                ``context_attn``, ``answer_pool_mask``,
                ``answer_phrases_pool_mask``, ``feedback_set``,
                ``feedback_attn_set`` and ``feedback_pool_mask_set``. Only
                element 0 of the batch dimension of the set/phrase fields is
                used.

        Returns:
            Dict with ``sent_ce_loss`` and ``avg_phrase_ce_loss`` (scalar
            cross-entropy losses against the candidate at index 0),
            ``sent_probs`` (one probability per candidate) and
            ``phrase_probs`` (one row of candidate probabilities per answer
            sentence).
        """
        sent_model_out = self.model(input_ids = b['context_w_feedback'].to(self.device),attention_mask=b['context_attn'].to(self.device))[0]
        feedback_model_out = self.model(input_ids = b['feedback_set'][0].to(self.device),attention_mask=b['feedback_attn_set'][0].to(self.device))[0]

        sent_emb = self.mean_pooling( sent_model_out, b['answer_pool_mask'].to(self.device))
        feedback_emb = self.mean_pooling( feedback_model_out, b['feedback_pool_mask_set'][0].to(self.device))

        # Tile the context hidden states so each sentence mask pools its own copy.
        phrase_masks = b['answer_phrases_pool_mask'][0].to(self.device)
        phrase_emb = self.mean_pooling( sent_model_out.repeat(phrase_masks.shape[0],1,1), phrase_masks )

        cos_sim = F.cosine_similarity(sent_emb,feedback_emb,dim=1)
        # Embeddings are unit-normalised, so the dot product is the cosine.
        cos_phrase_sim = torch.matmul(phrase_emb,feedback_emb.transpose(1,0))

        tgt_tensor = torch.zeros(b['feedback_set'].shape[1] , device=self.device)
        tgt_tensor[0] = 1.0 #the relevant feedback is always present at index 0

        return_dict = {'sent_ce_loss': F.cross_entropy(cos_sim,target=tgt_tensor),
                       'avg_phrase_ce_loss': F.cross_entropy(cos_phrase_sim.mean(0),target=tgt_tensor),
                       'sent_probs': F.softmax(cos_sim,dim=-1),
                       'phrase_probs': F.softmax(cos_phrase_sim,dim=-1)}

        return return_dict
        
        

In [None]:
from accelerate import Accelerator, notebook_launcher

In [None]:
def train(discriminator,train_dl,valid_dl,epochs,batch_size,optimizer,PATIENCE=20,save_dir=None):
    """Train ``discriminator`` with gradient accumulation and early stopping.

    The loss of every batch is the sum of ``sent_ce_loss`` and
    ``avg_phrase_ce_loss`` returned by the model. Gradients are accumulated
    over ``batch_size`` loader batches before each optimizer step. After every
    epoch the model is validated; the best checkpoint so far is written to
    ``{save_dir}/best_model_chkpt.pth.tar`` under the key ``model_dict``.

    Args:
        discriminator: module called as ``discriminator(batch)`` that returns
            a dict with ``sent_ce_loss`` and ``avg_phrase_ce_loss``.
        train_dl: iterable of training batches.
        valid_dl: sized iterable of validation batches.
        epochs: maximum number of epochs.
        batch_size: number of loader batches accumulated per optimizer step.
        optimizer: optimizer over the discriminator's parameters.
        PATIENCE: number of consecutive epochs without validation improvement
            tolerated before training stops.
        save_dir: checkpoint directory (created if missing). ``None`` disables
            checkpointing.

    Returns:
        ``(train_loss_arr, valid_loss_arr)``. ``train_loss_arr`` has one entry
        per optimizer step: the running sum of batch losses divided by the
        number of steps taken. ``valid_loss_arr`` has one entry per validation
        pass: the summed validation loss divided by ``len(valid_dl)``.
    """
    # Function-scope import keeps the function independent of later notebook cells.
    import os

    if save_dir is not None:
        # exist_ok avoids a failure when several processes create the directory.
        os.makedirs(save_dir, exist_ok=True)

    accelerator = Accelerator()
    discriminator.device = accelerator.device

    discriminator, train_dl, valid_dl, optimizer = accelerator.prepare(discriminator, train_dl, valid_dl, optimizer)

    def validate(discriminator,valid_dl):
        """Return the summed validation loss, averaged over processes."""
        discriminator.eval()
        valid_loss = 0
        with torch.no_grad():
            for b in valid_dl:
                y = discriminator(b)
                loss = y['sent_ce_loss'] + y['avg_phrase_ce_loss']
                valid_loss += loss.item()

        # Every process must see the same value, otherwise the early-stopping
        # decisions below could diverge between processes.
        per_process = accelerator.gather(torch.tensor([valid_loss], device=accelerator.device))
        valid_loss = per_process.mean().item()

        accelerator.print("Validation Loss:",valid_loss)
        return valid_loss

    loss_acc = 0
    total_steps = 0

    patience = PATIENCE

    train_loss_arr,valid_loss_arr = [],[]

    optimizer.zero_grad()
    discriminator.zero_grad()

    valid_loss = validate(discriminator,valid_dl)
    valid_loss_arr.append(valid_loss/len(valid_dl))
    best_valid_loss = valid_loss

    for E in range(epochs):

        # validate() leaves the model in eval mode, so switch back every epoch.
        discriminator.train()

        num_samples = 0

        for b in train_dl:

            y = discriminator(b)
            loss = y['sent_ce_loss'] + y['avg_phrase_ce_loss']

            num_samples+=1

            accelerator.backward(loss)
            loss_acc += loss.item()

            # Gradient accumulation: step once every batch_size loader batches.
            # Gradients of a trailing partial group carry over to the next epoch.
            if num_samples%batch_size==0:
                optimizer.step()

                total_steps += 1

                train_loss_arr.append(loss_acc/total_steps)

                optimizer.zero_grad()

                if total_steps%100==0:
                    accelerator.print("Epoch:",E,"\t","Steps taken:",total_steps,"\tLoss:",loss_acc/total_steps)

        valid_loss = validate(discriminator,valid_dl)
        valid_loss_arr.append(valid_loss/len(valid_dl))

        if valid_loss<best_valid_loss:
            best_valid_loss = valid_loss
            patience = PATIENCE

            accelerator.wait_for_everyone()
            state_dict = accelerator.get_state_dict(discriminator)
            # Only one process writes, so concurrent writers cannot corrupt the file.
            if save_dir is not None and accelerator.is_main_process:
                torch.save({'model_dict':state_dict},f'{save_dir}/best_model_chkpt.pth.tar')
        else:
            patience -= 1
            accelerator.print(f"REDUCING PATIENCE...{patience}")

        if patience<=0:
            accelerator.print("RUNNING OUT OF PATIENCE... TERMINATING")
            break

    return train_loss_arr,valid_loss_arr
                

In [None]:
import os

from transformers import AutoModel

# Maximum number of training epochs.
EPOCHS = 50
# Number of loader batches accumulated per optimizer step (the loaders
# themselves yield one sample per batch).
BATCH_SIZE = 8

# MPNet = AutoModel.from_pretrained(bert_chkpt).to(device)
discriminator_model = discriminator(bert_chkpt)

optimizer = torch.optim.AdamW(discriminator_model.parameters(),lr=1e-5)

# Directory that receives the best checkpoint; created here if missing.
save_dir = 'Detect_Span_FB_Llama_chkpts_1'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

In [None]:
# Multi-process alternative to calling train() directly. notebook_launcher
# does not hand back the launched function's return value, so the result is
# not unpacked here; the best checkpoint is written to save_dir by train().
# The loss curves are only available from an in-process train() call.
notebook_launcher(train,args=(discriminator_model,train_DL,valid_DL,EPOCHS,BATCH_SIZE,optimizer,5,save_dir),num_processes=2)

In [None]:
# In-process training run; returns the per-step training losses and the
# per-validation-pass losses.
train_loss, valid_loss = train(
    discriminator=discriminator_model,
    train_dl=train_DL,
    valid_dl=valid_DL,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    optimizer=optimizer,
    PATIENCE=5,
    save_dir=save_dir,
)

In [None]:
import json

# Persist both loss curves as JSON lists in the working directory.
with open('train_loss.json','w') as f:
    json.dump(train_loss,f)

with open('valid_loss.json','w') as f:
    json.dump(valid_loss,f)

In [None]:
# train_loss has one entry per optimizer step, valid_loss one per validation
# pass. Pick len(valid_loss) evenly spaced training entries so that both
# curves have the same length and can share one frame.
train_loss_ds = np.array(train_loss)[np.round(np.linspace(0, len(train_loss) - 1, len(valid_loss))).astype(int)]
loss_df = pd.DataFrame({'train_loss':train_loss_ds , 'valid_loss':valid_loss})

In [None]:
from plotly import express as px
# Plot the two loss curves against the row index of loss_df.
px.line(loss_df,y=['train_loss','valid_loss'])

In [None]:
# Restore the best weights written by train(): the checkpoint lives in
# save_dir and stores the state dict under 'model_dict'. The weights are
# loaded into the model instance (discriminator is the class itself).
best_chkpt = torch.load(f'{save_dir}/best_model_chkpt.pth.tar', map_location='cpu')
discriminator_model.load_state_dict(best_chkpt['model_dict'])

In [None]:
# Qualitative check on a few training samples: print the context, the true
# feedback and the probabilities the discriminator assigns to the candidates
# (index 0 is the true feedback).
discriminator_model.eval()
i = 0
with torch.no_grad():
    for b in train_DL:
        out = discriminator_model(b)
        # Multiplying by context_attn zeroes out the feedback and padding ids.
        print(tokenizer.decode(torch.mul(b['context_w_feedback'][0],b['context_attn'][0]),skip_special_tokens=True))
        print(tokenizer.decode(b['feedback'][0],skip_special_tokens=True))
        print(out['sent_probs'],out['phrase_probs'])
        print("--------------------------------------------------------")
        i+=1
        if i>10:
            break