In [7]:
import numpy as np
import pandas as pd
import tqdm
from tqdm import tqdm_notebook
import torch
import plotly.express as px
from nltk import tokenize
import itertools
from torch import nn
import torch.nn.functional as F
import os

In [9]:
import nltk
NLTK_DATA_DIR = '/home/jupyter/Ravi_new/nltk_data'
punkt_ok = nltk.download('punkt',download_dir=NLTK_DATA_DIR)
# nltk only searches the directories in nltk.data.path; a custom download_dir is not
# registered automatically, so add it or sent_tokenize may not find punkt.
# NOTE(review): newer nltk releases look up 'punkt_tab' for sent_tokenize — confirm the version.
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DATA_DIR)
punkt_ok

[nltk_data] Downloading package punkt to
[nltk_data]     /home/jupyter/Ravi_new/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Use the first GPU when present; otherwise fall back to CPU so `device` is always defined.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## Data Preparation

In [None]:
def flatten_data(csv):
    """Unroll the grouped Scarecrow CSV into one record per annotated error span.

    Each CSV row holds a generation plus a `responses` column whose text is parsed
    into a list of annotator responses. Each response is a list of entries indexed
    as (error, feedback, severity, span_beg, span_end).

    Args:
        csv: anything `pd.read_csv` accepts (path or buffer).

    Returns:
        pd.DataFrame with columns id, prompt, gen_text, error, feedback, severity,
        span_beg, span_end, span, where span == gen_text[span_beg:span_end].
    """
    grouped = pd.read_csv(csv)
    records = []

    for row_no in tqdm.tqdm(range(len(grouped)),desc='unrolling data'):
        row = grouped.loc[row_no]
        row_id = row['id']
        prompt = row['prompt']
        gen_text = row['generation']
        # SECURITY NOTE(review): `eval` executes arbitrary Python found in the CSV. If the
        # file is not fully trusted, `ast.literal_eval` is the safer parser. Confirm it
        # accepts every `responses` value before switching.
        annotations = eval(row['responses'])

        for response in annotations:
            # an empty response simply contributes no records
            for entry in response:
                error = entry[0]
                # map the _SEP_ / _QUOTE_ placeholders back to ',' and '"'
                feedback = entry[1].replace("_SEP_",",").replace("_QUOTE_",'"')
                severity = entry[2]
                beg = entry[3]
                end = entry[4]

                records.append({"id": row_id,
                                "prompt": prompt,
                                "gen_text": gen_text,
                                "error": error,
                                "feedback": feedback,
                                "severity": severity,
                                "span_beg": beg,
                                "span_end": end,
                                "span": gen_text[beg:end]})

    return pd.DataFrame(records)

In [None]:
data = flatten_data('../grouped_data.csv')

# Keep only the language-quality error categories.
language_errors = ['Grammar_Usage', 'Off-prompt', 'Redundant', 'Self-contradiction', 'Incoherent']
# .copy() so later column assignments (span_len, gen_sentences, ...) modify an independent
# frame instead of a filtered view (avoids SettingWithCopyWarning / chained assignment).
data = data[data['error'].isin(language_errors)].copy()
# data = data[data['severity']>1]

In [None]:
len(data)

In [None]:
data['span_len'] = data['span'].apply(lambda x: len(x))
#px.histogram(data['span_len'],nbins=100)

In [None]:
data[data['span_len']>20].shape

In [None]:
data = data[data['span_len']>20]

In [None]:
data['error'].value_counts()

In [None]:
data['gen_sentences'] = data['gen_text'].apply(lambda x: tokenize.sent_tokenize(x))
data['span_is_sentence'] = [1 if x in y else 0 for x,y in zip(data['span'],data['gen_sentences'])]

In [None]:
data['span_is_sentence'].value_counts()

In [None]:
# Spot-check one row: joining its sentences with single spaces should rebuild text in
# which the annotated character offsets still select the span.
i = 392
#data.loc[i]['gen_sentences'] = [s.strip() for s in data.loc[i]['gen_sentences']]
s = " ".join(data.iloc[i]['gen_sentences'])
s

In [None]:
# original generation text, for comparison with the re-joined sentences
data.iloc[i]['gen_text']

In [None]:
data.iloc[i]['span']

In [None]:
data.iloc[i][['span_beg','span_end']]

In [None]:
# True when slicing the re-joined text at the annotated offsets yields the span
s[data.iloc[i]['span_beg']:data.iloc[i]['span_end']]==data.iloc[i]['span']

In [None]:
##check to see if the above technique is working fine for most data points
data['tech_works'] = [" ".join(data.iloc[i]['gen_sentences'])[data.iloc[i]['span_beg']:data.iloc[i]['span_end']] == data.iloc[i]['span'] for i in range(len(data))]

In [None]:
data['tech_works'].value_counts()

In [None]:
data = data[data['tech_works']==True]

In [None]:
data = data.reset_index()

In [None]:
def label_err_sentence(sentences,span_beg,span_end,multi_class=False,error_type=None):
    """Label each sentence by whether it overlaps the annotated error span.

    The sentences are assumed to tile a text equal to `" ".join(sentences)` (the
    `tech_works` filter keeps only such rows). Sentence k therefore starts at
    sum(len(s) + 1 for s in sentences[:k]).

    The span follows Python slice semantics and covers characters
    [span_beg, span_end), matching `span = gen_text[beg:end]` in flatten_data.

    Args:
        sentences: list of sentence strings.
        span_beg: inclusive start offset of the error span.
        span_end: exclusive end offset of the error span.
        multi_class: if True, overlapping sentences get `error_type + 1` instead of 1.
        error_type: integer class id; required when multi_class is True.

    Returns:
        list[int] with one entry per sentence: 0 when the sentence shares no character
        with the span, otherwise the positive label.
    """
    label = error_type + 1 if multi_class else 1

    output = []
    start = 0
    for s in sentences:
        end = start + len(s)  # exclusive end of this sentence
        # Half-open intervals [start, end) and [span_beg, span_end) overlap iff each one
        # starts before the other ends. Inclusive comparisons would also flag a neighbouring
        # sentence when the span merely touches the separating space.
        output.append(label if start < span_end and end > span_beg else 0)
        start = end + 1  # skip the single joining space
    return output

In [None]:
# Spot-check label_err_sentence on one row (error_type=1 is an arbitrary example value).
i = 101
data.iloc[i]['gen_sentences']

In [None]:
data.iloc[i]['span']

In [None]:
# with multi_class=True, sentences matched to the span get error_type + 1 (here 2), others 0
label_err_sentence(data.iloc[i]['gen_sentences'],data.iloc[i]['span_beg'],data.iloc[i]['span_end'],multi_class=True,error_type=1)

In [None]:
data['error'].value_counts()

In [None]:
# Integer class id for every Scarecrow error category; ids follow the listed order.
ERROR_MAP = {name: code for code, name in enumerate([
    'Redundant',
    'Off-prompt',
    'Grammar_Usage',
    'Incoherent',
    'Self-contradiction',
    'Needs_Google',
    'Technical_Jargon',
    'Commonsense',
    'Encyclopedic',
    'Bad_Math',
])}

# Reverse lookup: class id -> error category name.
INV_ERROR_MAP = dict(zip(ERROR_MAP.values(), ERROR_MAP.keys()))

In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

# Mean pooling: average the token embeddings over the positions the attention mask keeps.
def mean_pooling(model_output, attention_mask):
    """Masked mean of the token embeddings held in `model_output[0]`.

    Args:
        model_output: indexable whose first element is the token-embedding tensor.
        attention_mask: mask that is unsqueezed on the last axis and expanded to the
            embedding tensor's size.

    Returns:
        Sum over dim 1 of the masked embeddings divided by the number of kept positions
        (clamped to at least 1e-9 to avoid division by zero). No L2 normalisation is applied.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    kept = weights.sum(dim=1).clamp(min=1e-9)
    return summed / kept

# Fall back to CPU when no GPU is visible instead of failing on .to('cuda:0').
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

import os
# NOTE(review): CUDA_LAUNCH_BLOCKING is normally read when the CUDA context is initialised,
# so setting it this late may have no effect. Confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

HF_CACHE_DIR = '/home/jupyter/Ravi_new/HF_cache'

# Load model from HuggingFace Hub.
# NOTE(review): despite the names `bart_chkpt` / `bert_model`, this loads a Llama-2 chat causal LM
# (LlamaTokenizer + AutoModelForCausalLM). Later cells pass `bart_chkpt` to
# BartForConditionalGeneration; confirm which checkpoint is intended there.
# Alternative checkpoints listed in the original comments: 'decapoda-research/llama-13b-hf-int8',
# 'sentence-transformers/all-MiniLM-L6-v2', 'SpanBERT/spanbert-large-cased',
# 'sentence-transformers/all-distilroberta-v1'.
bart_chkpt = 'meta-llama/Llama-2-7b-chat-hf'
tokenizer = LlamaTokenizer.from_pretrained(bart_chkpt,cache_dir=HF_CACHE_DIR)
bert_model = AutoModelForCausalLM.from_pretrained(bart_chkpt,cache_dir=HF_CACHE_DIR)
bert_model = bert_model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
bert_model

In [18]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

ctxt = "Hey, how are you doing?"
r1 = "I'm doing fine."
r2 = "I'm doing fine. I'm doing fine. I'm doing fine. I'm doing fine. I'm doing fine."

# ctxt = 'Hey, how are you doing?'
# r1 = 'I am fine.'
# r2 = 'I am fine. I am fine. I am fine.'

tok_ctxt = tokenizer(ctxt, add_special_tokens=False)
tok_pad = [-100]*len(tok_ctxt['input_ids'][:])
tok_r1 = tokenizer(r1, add_special_tokens=False)
tok_r2 = tokenizer(r2, add_special_tokens=False)

#can use a stack structure over response and loop over it until empty to compute loss for each token in response

with torch.no_grad():
    out1 = bert_model(input_ids=torch.tensor(tok_ctxt.input_ids+tok_r1.input_ids[:]).unsqueeze(0).to('cuda:0'),
                      labels=torch.tensor(tok_pad+tok_r1.input_ids).unsqueeze(0).to('cuda:0'))
    out2 = bert_model(input_ids=torch.tensor(tok_ctxt.input_ids+tok_r2.input_ids[:]).unsqueeze(0).to('cuda:0'),
                      labels=torch.tensor(tok_pad+tok_r2.input_ids).unsqueeze(0).to('cuda:0'))
out1.loss, out2.loss

(tensor(1.4565, device='cuda:0'), tensor(0.7584, device='cuda:0'))

In [None]:
tokenizer(ctxt, add_special_tokens=False), tokenizer(r1, add_special_tokens=False).input_ids

In [None]:
out1.last_hidden_state.shape

In [None]:
tokenizer.bos_token,tokenizer.eos_token

In [None]:
tokenizer.all_special_tokens

In [None]:
sum(tokenizer('Hey there how are you doing?',return_length=True,add_special_tokens=True,padding='max_length',max_length=20).attention_mask)

In [None]:
class scarecrow_dataset(Dataset):
    """Per-sentence feedback-generation examples built from the flattened Scarecrow frame.

    For every row of `df` and every sentence j of its `gen_sentences`, one example is built:
        input    = [bos] + prompt + all generated sentences + [eos] + sentence_j (padded)
        feedback = the row's annotator feedback if label_err_sentence marks sentence j,
                   otherwise 'This sentence looks good.' (padded to max_length)
    A sentence is skipped when it does not fit in the token budget left after the context.

    Relies on notebook globals: `tokenizer`, `ERROR_MAP`, `label_err_sentence`.

    Args:
        df: frame with columns id, prompt, gen_sentences, span_beg, span_end, error, feedback.
        max_length: token budget for the input and the padded feedback.
        MULTI_CLASS: forwarded to label_err_sentence as `multi_class`.
    """

    def __init__(self,df,max_length=1024,MULTI_CLASS=False):
        self.df = df
        self.max_len = max_length
        self.data = []

        for i in tqdm.tqdm(range(len(self.df)),desc='vectorizing..'):
            row = self.df.iloc[i]
            error_id = ERROR_MAP[row['error']]
            sentences = row['gen_sentences']

            label_err_sent = label_err_sentence(sentences,row['span_beg'],row['span_end'],multi_class=MULTI_CLASS,error_type=error_id)

            if len(sentences) == 0:
                # no examples to emit; also avoids tokenizing an empty batch
                continue

            # The prompt + generation context is the same for every sentence of this row,
            # so tokenize it once per row instead of once per sentence.
            tok_prompt = tokenizer(row['prompt'],return_token_type_ids=True,add_special_tokens=False)
            tok_gen_text = tokenizer(sentences,return_token_type_ids=True,add_special_tokens=False)
            context_ids = ([tokenizer.bos_token_id] + tok_prompt['input_ids']
                           + list(itertools.chain.from_iterable(tok_gen_text['input_ids']))
                           + [tokenizer.eos_token_id])
            room = self.max_len - len(context_ids)  # tokens left for the sentence of interest

            for j,s in enumerate(sentences):
                # the sentence of interest must fit after the context for the model to see it in context
                if room < len(tok_gen_text['input_ids'][j]):
                    continue
                # NOTE(review): padding='max_length' needs tokenizer.pad_token; LlamaTokenizer may not
                # define one. Confirm which tokenizer is active.
                tok_sent_of_interest = tokenizer(s, add_special_tokens=False, max_length=room, padding='max_length')
                tok_input = context_ids + tok_sent_of_interest['input_ids']  # new list; context_ids is reused

                feedback = 'This sentence looks good.' if label_err_sent[j]==0 else row['feedback']

                tok_feedback = tokenizer(feedback,
                                         return_token_type_ids=True,
                                         add_special_tokens=True,
                                         return_length=True,
                                         max_length=self.max_len,
                                         padding='max_length',
                                         truncation='only_first')

                self.data.append({
                    'id': torch.tensor(row['id']),
                    'error': torch.tensor(error_id),
                    'input': torch.LongTensor(tok_input),
                    'feedback': torch.LongTensor(tok_feedback['input_ids']),
                    # position of the last attended feedback token
                    'index': torch.tensor(sum(tok_feedback.attention_mask)-1),
                    'class': torch.tensor(label_err_sent[j]),
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        return self.data[idx]

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
train_ids , test_ids = train_test_split(data.index,test_size=0.1,shuffle=True,random_state=42)

In [None]:
train_df , test_df = data.loc[train_ids], data.loc[test_ids]
train_df , valid_df = train_df.iloc[:int(0.9*len(train_df))], train_df.iloc[int(0.9*len(train_df)):]

In [None]:
train_dataset = scarecrow_dataset(train_df,MULTI_CLASS=False)
valid_dataset = scarecrow_dataset(valid_df,MULTI_CLASS=False)
test_dataset  = scarecrow_dataset(test_df,MULTI_CLASS=False)

In [None]:
train_DL = DataLoader(train_dataset,batch_size=10,shuffle=True)
valid_DL = DataLoader(valid_dataset,batch_size=10,shuffle=True)
test_DL = DataLoader(test_dataset,batch_size=1,shuffle=False)

In [None]:
from transformers import BartForConditionalGeneration
# Smoke test: one teacher-forced forward pass of a fresh seq2seq model on a training batch.
# NOTE(review): `bart_chkpt` is set to the Llama-2 checkpoint in the model-loading cell, and the
# datasets were tokenized with LlamaTokenizer; BartForConditionalGeneration expects a BART
# checkpoint and vocabulary. Confirm the intended checkpoint/tokenizer pair.
bart_model = BartForConditionalGeneration.from_pretrained(bart_chkpt).to(device)
pad_id = bart_model.config.pad_token_id
with torch.no_grad():
    for b in train_DL:
        print(b['index'].shape)
        inp = b['input'].to(device)
        labels = b['feedback'][:,1:].to(device)
        attn = None
        if pad_id is not None:
            # don't attend to padded encoder positions, and score no padding in the loss
            # (-100 is the ignore index of the HF seq2seq loss)
            attn = inp.ne(pad_id).long()
            labels = labels.masked_fill(labels == pad_id, -100)
        output = bart_model(input_ids=inp,
                            attention_mask=attn,
                            decoder_input_ids=b['feedback'][:,:-1].to(device),
                            labels=labels,
                            output_hidden_states=True).loss
        print(output)
        break
del bart_model

## Modelling

In [None]:
class binary_classifier(nn.Module):
    """Thin wrapper around a pretrained BartForConditionalGeneration used as a feedback generator.

    Despite the name, no classification head is built. Only `bart_chkpt` is used;
    `inp_dim`, `hidden_dims`, `num_classes`, `use_norm` and `do_softmax` are accepted
    for compatibility with existing calls and are ignored.
    """

    def __init__(self,bart_chkpt='facebook/bart-base',inp_dim=768,hidden_dims=None,num_classes=2,use_norm=False,do_softmax=False):
        super().__init__()
        self.bart_model = BartForConditionalGeneration.from_pretrained(bart_chkpt)

    def forward(self,inp,ref,labels):
        """Teacher-forced seq2seq pass.

        Args:
            inp: encoder input ids.
            ref: decoder input ids (the target shifted right by the caller).
            labels: target ids aligned with `ref`.

        Returns:
            The output of `self.bart_model`; callers read its `.loss`.
        """
        pad_id = getattr(self.bart_model.config, 'pad_token_id', None)
        attention_mask = None
        if pad_id is not None:
            # Padded positions should neither be attended to by the encoder nor scored in the
            # loss; -100 is the ignore index of the HF loss. masked_fill returns a new tensor,
            # so the caller's labels are not modified.
            # NOTE(review): if a checkpoint uses pad_token_id == eos_token_id, EOS targets are masked too.
            attention_mask = inp.ne(pad_id).long()
            labels = labels.masked_fill(labels == pad_id, -100)
        y = self.bart_model(input_ids=inp,attention_mask=attention_mask,decoder_input_ids=ref,labels=labels)
        return y

In [None]:
def train(classifier,train_dl,valid_dl,epochs,optimizer,PATIENCE=20,save_dir=None):
    """Fine-tune `classifier` with teacher forcing and early stopping on validation loss.

    Every batch feeds `feedback[:, :-1]` as decoder input and `feedback[:, 1:]` as labels
    and takes one optimizer step on `y.loss`. After each epoch the model is checkpointed,
    `validate` is run, and the best model so far (lowest summed validation loss) is saved
    as `best_model_chkpt.pth.tar`. Training stops once validation has not improved for
    PATIENCE consecutive epochs.

    Args:
        classifier: module whose forward(input, decoder_input, labels) returns an object with `.loss`.
        train_dl: DataLoader of training batches (dicts with 'input' and 'feedback').
        valid_dl: DataLoader of validation batches, same format.
        epochs: maximum number of epochs.
        optimizer: optimizer over the parameters to train.
        PATIENCE: epochs without validation improvement before stopping.
        save_dir: checkpoint directory, created if missing. If None, no checkpoints are written.

    Returns:
        (train_loss_arr, valid_loss_arr): the running-mean training loss after every step
        (averaged over all steps so far, not reset per epoch) and, per completed epoch,
        the validation loss divided by the number of validation batches.
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    loss_acc = 0
    num_batches = 0
    total_steps = 0
    best_valid_loss = np.inf
    patience = PATIENCE

    train_loss_arr,valid_loss_arr = [],[]

    optimizer.zero_grad()
    classifier.zero_grad()

    for E in range(epochs):
        # validate() switches the model to eval mode; switch back at the start of every
        # epoch so dropout etc. stay active for all epochs, not only the first.
        classifier.train()

        for b in train_dl:
            y = classifier(b['input'].to(device),b['feedback'][:,:-1].to(device),b['feedback'][:,1:].to(device))
            loss = y.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_acc += loss.item()
            num_batches += 1
            total_steps += 1

            train_loss_arr.append(loss_acc/num_batches)

            if total_steps%100==0:
                print("Epoch:",E,"\t","Steps taken:",total_steps,"\tLoss:",loss_acc/num_batches)

        if save_dir is not None:
            torch.save({'model_state':classifier.state_dict(),
                        'optimizer':optimizer.state_dict(),
                        'epoch':E},
                       f"{save_dir}/Epoch_{E}_model_chkpt.pth.tar")

        valid_loss = validate(classifier,valid_dl)
        # max(..., 1) guards against an empty validation loader
        valid_loss_arr.append(valid_loss/max(len(valid_dl),1))

        if valid_loss<best_valid_loss:
            best_valid_loss = valid_loss
            patience = PATIENCE

            if save_dir is not None:
                torch.save({'model_state':classifier.state_dict(),
                            'optimizer':optimizer.state_dict(),
                            'epoch':E},
                           f"{save_dir}/best_model_chkpt.pth.tar")
        else:
            patience -= 1
            print(f"REDUCING PATIENCE...{patience}")

        if patience<=0:
            print("RUNNING OUT OF PATIENCE... TERMINATING")
            break

    return train_loss_arr,valid_loss_arr
                

In [None]:
def validate(classifier,valid_dl):
    """Return the teacher-forced loss of `classifier` summed over all batches of `valid_dl`.

    Runs in eval mode under torch.no_grad(). The model's previous train/eval mode is
    restored before returning (even on error), so a caller's training loop is not left
    with dropout silently disabled.

    Args:
        classifier: module whose forward(input, decoder_input, labels) returns an object with `.loss`.
        valid_dl: iterable of batches (dicts with 'input' and 'feedback').

    Returns:
        float sum of per-batch losses (0 for an empty loader).
    """
    was_training = classifier.training
    classifier.eval()
    valid_loss = 0
    try:
        with torch.no_grad():
            for b in valid_dl:
                y = classifier(b['input'].to(device),b['feedback'][:,:-1].to(device),b['feedback'][:,1:].to(device))
                valid_loss += y.loss.item()
    finally:
        classifier.train(was_training)

    print("Validation Loss:",valid_loss)
    return valid_loss

In [None]:
# NOTE(review): `bart_chkpt` currently names the Llama-2 checkpoint loaded earlier, while
# binary_classifier wraps BartForConditionalGeneration. Confirm which checkpoint is intended.
classifier = binary_classifier(bart_chkpt=bart_chkpt, inp_dim=768, hidden_dims=[768,256], num_classes=2, use_norm=True)
classifier.to(device);

# Optional warm start from an earlier run (set to None to start from the pretrained weights).
WARM_START_CHKPT = 'TempFB_BART_1/best_model_chkpt.pth.tar'
# binary_classifier has no separate head; its only submodule is `bart_model` (there is no
# `bert_model` attribute). Freezing it leaves no trainable parameters, so loss.backward() in
# train() is expected to fail. Keep False to fine-tune.
FREEZE_BACKBONE = False

if WARM_START_CHKPT is not None:
    classifier.load_state_dict(torch.load(WARM_START_CHKPT, map_location=device)['model_state'])
if FREEZE_BACKBONE:
    classifier.bart_model.requires_grad_(False)

In [None]:
# AdamW over every parameter of the wrapper.
optimizer = torch.optim.AdamW(params=classifier.parameters(), lr=1e-5)

# Directory that receives the per-epoch and best checkpoints written by train().
save_dir = 'GenFB_BART_chkpts_1'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

train_loss, valid_loss = train(
    classifier, train_DL, valid_DL,
    epochs=50,
    optimizer=optimizer,
    PATIENCE=5,
    save_dir=save_dir,
)

In [None]:
px.line(valid_loss)

In [None]:
i = 0

model = classifier.bart_model
model.eval()

with torch.no_grad():
    for b in test_DL:
        gen = model.generate(b['input'].to(device))
        print('\n---------------------------------\nContext: ',tokenizer.decode(b['input'][0],skip_special_tokens=True),'\n\n')
        print('Human Feedback: ',tokenizer.decode(b['feedback'][0],skip_special_tokens=True),'\n\n')
        print('Generated Feedback: ',tokenizer.decode(gen[0],skip_special_tokens=True),'\n--------------------------------\n')
        i += 1
        if i>50: break

In [None]:
from torchmetrics.functional import precision_recall as PR
from torchmetrics.classification import BinaryJaccardIndex as JI

#bji = JI().to('cuda:0')

def test(classifier,test_dl,binary=False,chkpt=None):
    """Evaluate `classifier` on `test_dl` as a per-sentence classifier.

    Args:
        classifier: model to evaluate; if `chkpt` is given its 'model_state' is loaded first.
        test_dl: DataLoader yielding dicts with 'input', 'feedback', 'index', 'class', 'error'.
            The `.item()` calls below require single-element tensors, i.e. batch_size=1.
        binary: if True, predictions and labels above 1 are clamped to 1 before scoring.
        chkpt: optional checkpoint path (a dict holding 'model_state').

    Returns:
        (P, R, acc, gt, preds, error_type): per-example precision and recall from
        torchmetrics, the number of correct predictions, ground-truth labels,
        predicted labels, and error-type names.

    NOTE(review): this expects `classifier(...)` to return class logits, but
    binary_classifier.forward(inp, ref, labels) returns the BART model output and here
    receives b['index'] as `labels`; `logits.argmax` will likely fail. The function
    looks written for an earlier classification head. Confirm before relying on it.
    NOTE(review): torchmetrics.functional.precision_recall may not exist in recent
    torchmetrics releases. Confirm the installed version.
    """
    
    P = []
    R = []
    acc = 0
    gt,preds = [],[]
    error_type = []
#     JIdx = []
    
    if chkpt!=None:
        classifier.load_state_dict(torch.load(chkpt)['model_state'])
    classifier.eval()
    
    with torch.no_grad():
        for b in tqdm_notebook(test_dl):

            logits = classifier(b['input'].to(device),b['feedback'].to(device),b['index'])
#             print(logits)
#             break
            out = logits.argmax(dim=-1)
            gl = b['class']
            if binary:
                # collapse multi-class labels/predictions to {0, 1}
                out = torch.clamp(out,max=1)
                gl = torch.clamp(gl,max=1)
            #print(out)
            gt.append(gl.item())
            preds.append(out.item())
            if out.item()==gl.item():
                acc+=1
            p_r = PR(preds=logits, target=gl.to(device))


            #JIdx.append(ji.item())
            P.append(p_r[0].item())
            R.append(p_r[1].item())
            error_type.append(INV_ERROR_MAP[b['error'].item()])
            #print(p_r)
    
    return P,R,acc,gt,preds,error_type#acc/len(test_dl)

def test_2(bin_classifier,multi_classifier,test_dl,bin_chkpt=None,multi_chkpt=None):
    """Two-stage evaluation: a binary detector first, then a multi-class typer on its positives.

    For each batch, if the binary classifier predicts 1 the multi-class classifier's argmax
    is used as the prediction; otherwise the binary prediction itself is used.

    Args:
        bin_classifier: model returning binary logits from (input, sent_mask).
        multi_classifier: model returning error-type logits from (input, sent_mask).
        test_dl: DataLoader yielding dicts with 'input', 'sent_mask', 'class', 'error';
            the `.item()` calls require batch_size=1.
        bin_chkpt, multi_chkpt: optional checkpoint paths (dicts holding 'model_state').

    Returns:
        (P, R, acc, gt, preds, error_type). P and R are always empty lists here
        (precision/recall are left to the caller); acc is the count of correct predictions.

    NOTE(review): scarecrow_dataset batches have no 'sent_mask' key; this function
    appears to target an earlier dataset/model. Confirm before use.
    """
    P = []
    R = []
    acc = 0
    gt,preds = [],[]
    error_type = []

    if bin_chkpt is not None:
        bin_classifier.load_state_dict(torch.load(bin_chkpt)['model_state'])
    if multi_chkpt is not None:
        multi_classifier.load_state_dict(torch.load(multi_chkpt)['model_state'])

    bin_classifier.eval()
    multi_classifier.eval()

    with torch.no_grad():
        for b in tqdm_notebook(test_dl):
            logits = bin_classifier(b['input'].to(device),b['sent_mask'].to(device))
            out = logits.argmax(dim=-1)
            if out.item()==1:
                # flagged as erroneous: ask the multi-class model for the error type
                m_logits = multi_classifier(b['input'].to(device),b['sent_mask'].to(device))
                pred = m_logits.argmax(dim=-1).item()
            else:
                pred = out.item()
            label = b['class'].item()

            gt.append(label)
            preds.append(pred)
            if pred==label:
                acc+=1
            error_type.append(INV_ERROR_MAP[b['error'].item()])

    return P,R,acc,gt,preds,error_type

In [None]:
# P,R,acc,GT,PREDS,error_type = test_2(bin_classifier,multi_classifier,test_DL,bin_chkpt='classifier_chkpts_E2E_3/best_model_chkpt.pth.tar',multi_chkpt='classifier_chkpts_E2E_7/best_model_chkpt.pth.tar')
# Evaluate the best checkpoint of the training run in this notebook: train() writes it to
# f"{save_dir}/best_model_chkpt.pth.tar" (previously a different directory was hard-coded here).
P,R,acc,GT,PREDS,error_type = test(classifier,test_DL,binary=True,chkpt=f"{save_dir}/best_model_chkpt.pth.tar")
print(acc)

In [None]:
from sklearn.metrics import precision_score,recall_score

out_df = pd.DataFrame({'gt': GT, 'preds': PREDS, 'error_type': error_type})

# Macro precision/recall per error category.
for e in ERROR_MAP.keys():
    subset = out_df[out_df['error_type']==e]
    if subset.empty:
        # categories filtered out during data preparation (or absent from the test split) have no rows
        continue
    print('\nError:',e)
    # zero_division=0: report 0 instead of warning when a class is never predicted / present
    print('P:',precision_score(subset['gt'],subset['preds'],average='macro',zero_division=0))
    print('R:',recall_score(subset['gt'],subset['preds'],average='macro',zero_division=0))

In [None]:
# Overall (all error types pooled) macro-averaged metrics on the test predictions.
gt = out_df['gt']
preds = out_df['preds']

from sklearn.metrics import precision_score
precision_score(gt,preds,average='macro')

In [None]:
from sklearn.metrics import recall_score
recall_score(gt,preds,average='macro')

In [None]:
from sklearn.metrics import f1_score
f1_score(gt,preds,average='macro')

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay,confusion_matrix
# rows: true label, columns: predicted label
cm = confusion_matrix(gt,preds)
ConfusionMatrixDisplay(cm).plot()

In [None]:
# correct predictions / number of test batches; this is per-example accuracy because test_DL uses batch_size=1
acc/len(test_DL)