In [7]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
import os

In [8]:
# Directory of the fine-tuned checkpoint. Later cells load the tokenizer from
# this directory and read the weights from "pytorch_model.bin" inside it.
best_model_path = "dailydialog_v3/large_key_m10/checkpoint-38032"

In [9]:
class Model(torch.nn.Module):
    """T5 seq2seq model with an optional bag-of-words (keyword) auxiliary head.

    When ``use_key`` is true, a bias-free linear layer maps the decoder's
    last-layer hidden state at the first decoder position to vocabulary
    logits, and a cross-entropy loss against ``key_ids`` is added to the
    language-model loss.

    Args:
        model_name: Name or path handed to
            ``T5ForConditionalGeneration.from_pretrained``.
        tokenizer: Tokenizer whose ``len()`` sets the embedding-table size and
            the output size of the keyword head.
        use_key: Whether to build and train the keyword head.
        hidden_size: Input width of the keyword head; must match the width of
            the decoder hidden states.
        bow_wt: Weight of the keyword loss relative to the LM loss. The value
            used to train a given checkpoint is not recorded in this file, so
            pass it explicitly when computing training losses.
    """

    def __init__(self, model_name, tokenizer, use_key, hidden_size, bow_wt=1.0):
        super().__init__()
        self.lm = T5ForConditionalGeneration.from_pretrained(model_name)
        # Size the embedding table from len(tokenizer) so that any tokens
        # added on top of the base vocabulary get an embedding row.
        self.lm.resize_token_embeddings(len(tokenizer))

        # Keep the configuration on the instance: forward() must not depend on
        # same-named notebook globals, which do not exist on a fresh kernel.
        self.use_key = use_key
        self.bow_wt = bow_wt

        if use_key:
            self.bow_head = torch.nn.Linear(hidden_size, len(tokenizer), bias=False)

    def forward(self, input_ids, attention_mask, labels, key_ids=None):
        """Compute the training loss.

        Args:
            input_ids: Encoder token ids.
            attention_mask: Encoder attention mask.
            labels: Target token ids for the LM loss.
            key_ids: Keyword token ids, one row per example; required when the
                model was built with ``use_key=True``. Entries equal to -100
                are ignored by the loss.

        Returns:
            ``{'loss': tensor}``: the LM loss, plus ``bow_wt`` times the
            batch-averaged keyword loss when the keyword head is enabled.

        Raises:
            ValueError: If the keyword head is enabled and ``key_ids`` is None.
        """
        if not self.use_key:
            lm_out = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            return {'loss': lm_out.loss}

        if key_ids is None:
            raise ValueError("key_ids is required when the model is built with use_key=True")

        lm_out = self.lm(input_ids=input_ids, attention_mask=attention_mask,
                         labels=labels, output_hidden_states=True)
        # Last decoder layer, first decoder position: one vector per example.
        hidden = lm_out.decoder_hidden_states[-1]
        h = hidden[:, 0]
        fc1 = self.bow_head(h)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)

        b_size = key_ids.size(0)
        bow_loss = 0
        for i in range(b_size):
            bow_ids = key_ids[i]
            # The same logits row is scored against every keyword of example i.
            # NOTE(review): if every entry of bow_ids is -100 the mean is taken
            # over zero targets and is presumably NaN -- confirm that the data
            # pipeline guarantees at least one keyword per example.
            bow_logits = fc1[i].expand(bow_ids.size(0), -1)
            bow_loss = bow_loss + loss_fct(bow_logits, bow_ids)
        bow_loss = bow_loss / b_size

        loss = (self.bow_wt * bow_loss) + lm_out.loss
        return {'loss': loss}

In [13]:
# Rebuild the architecture from the base T5 weights, then overwrite every
# parameter with the fine-tuned state dict stored in the checkpoint directory.
MODEL_CKPT = "t5-large"
m_path = os.path.join(best_model_path, "pytorch_model.bin")

# The tokenizer is read from the checkpoint directory; its length determines
# the embedding and keyword-head sizes inside Model.
tokenizer = AutoTokenizer.from_pretrained(best_model_path)
raw_model = Model(
    model_name=MODEL_CKPT,
    tokenizer=tokenizer,
    use_key=True,
    hidden_size=1024,
)

# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
raw_model.load_state_dict(torch.load(m_path, map_location="cpu"))
raw_model.eval()  # inference mode
print("Model Loaded")

Model Loaded


In [67]:
# Dialogue contexts for a quick qualitative check; turns end with "<eou>".
# Both variants are kept under their own names so switching between them is a
# one-word change instead of relying on which assignment happens to run last.
ctx_long = "Hey man , you wanna buy some weed ?<eou>Some what ?<eou>Weed ! You know ? Pot , Ganja , Mary Jane some chronic !<eou>Oh , umm , no thanks .<eou>I also have blow if you prefer to do a few lines .<eou>"
ctx_short = "Hey man , you wanna buy some weed ?<eou>Some what ?<eou>Weed ! You know ? Pot , Ganja , Mary Jane some chronic !<eou>"
ctx = ctx_short  # context actually fed to the model below

# Single unpadded example, truncated to at most 512 tokens.
output = tokenizer(ctx, max_length=512, truncation=True, return_tensors="pt")

# Seed decoding with the model's configured decoder start token instead of a
# hard-coded literal (for T5 checkpoints this is the pad token, id 0).
dec_inp = torch.tensor(
    [[raw_model.lm.config.decoder_start_token_id]],
    dtype=output.input_ids.dtype,
)
print(output["input_ids"].shape)
print(dec_inp)

torch.Size([1, 43])
tensor([[0]])


In [68]:
# Qualitative probe: generate a reply for the context, then look at which
# vocabulary entries the keyword head scores highest for the same context.
with torch.no_grad():
    input_ids, attention_mask = output["input_ids"], output["attention_mask"]

    # Beam-search a response from the fine-tuned language model.
    out = raw_model.lm.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=dec_inp,
        num_beams=5,
        max_new_tokens=41,
        min_new_tokens=12,
        length_penalty=0.1,
    )
    utt = tokenizer.batch_decode(out, skip_special_tokens=True)
    print(utt)

    # One decoder step on the start token, keeping the hidden states.
    lm_out = raw_model.lm(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=dec_inp,
        output_hidden_states=True,
    )
    hidden = lm_out.decoder_hidden_states[-1]
    print(hidden.shape)

    # The keyword head reads the first decoder position of the last layer.
    h = hidden[:, 0]
    fc1 = raw_model.bow_head(h)
    print(fc1.shape)

    # Eight highest-scoring vocabulary ids and their token strings.
    val, indices = fc1.topk(8)
    lst_idx = indices.tolist()
    print(lst_idx)
    lst_tok = tokenizer.convert_ids_to_tokens(lst_idx[0])
    print(lst_tok)
    
    

["Yeah, that's what I was thinking."]
torch.Size([1, 1, 1024])
torch.Size([1, 32102])
[[32101, 3, 7, 17945, 63, 15, 9, 207]]
['<nok>', '▁', 's', '▁yeah', 'y', 'e', 'a', '▁good']


In [49]:
# Scratch check of torch.topk on a 2-D tensor: it works along the last
# dimension and returns (values, indices), largest first, for each row.
x = torch.tensor(
    [
        [-1, -2, 4, -100, -20],
        [1, 8, 4, -100, 20],
    ]
)
print(x)

val, indices = x.topk(k=3)
print(val, indices, sep="\n")

tensor([[  -1,   -2,    4, -100,  -20],
        [   1,    8,    4, -100,   20]])
tensor([[ 4, -1, -2],
        [20,  8,  4]])
tensor([[2, 0, 1],
        [4, 1, 2]])
