In [None]:
from copyreg import pickle
from operator import le
from readline import parse_and_bind
import torch
from transformers import AutoConfig,AutoModelForSeq2SeqLM,AutoTokenizer
from dataclasses import dataclass,field 
from typing import List
import math
import string





def setup_model(task='sum', dataset='xsum', model_name='facebook/bart-large-xsum', device_name='cuda:0'):
    device = torch.device(device_name)
    print(model_name)
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    if task == 'custom':
        # you need to store the input under the path_dataset folder
        dec_prefix = [tokenizer.eos_token_id]
        with open(os.path.join(dataset, 'input.txt'), 'r') as fd:
            slines = fd.read().splitlines()
        with open(os.path.join(dataset, 'output.txt'), 'r') as fd:
            tlines = fd.read().splitlines()
        dataset = zip(slines, tlines)
    elif task == 'sum':
        logging.info('Loading dataset')
        if dataset == 'xsum':
            dataset = load_dataset("xsum", split='validation')
        elif dataset == 'cnndm':
            raise NotImplementedError("not supported")
            dataset = load_dataset("cnn_dailymail", split='validation')
            print("CNNDM mean token in ref 56")
        dec_prefix = [tokenizer.eos_token_id]
    elif task == 'mt1n':
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
        model = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-one-to-many-mmt")
        tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
        assert dataset.startswith('en')
        tgt_lang = dataset[3:]
        dataset = read_mt_data(name=dataset)

        from transformers.models.mbart.tokenization_mbart import FAIRSEQ_LANGUAGE_CODES
        match = [x for x in FAIRSEQ_LANGUAGE_CODES if x.startswith(tgt_lang)]
        assert len(match) == 1
        lang = match[0]
        logging.info(f"Lang: {lang}")
        dec_prefix = [tokenizer.eos_token_id, tokenizer.lang_code_to_id[lang]]
        logging.info(f"{tokenizer.decode(dec_prefix)}")
    elif task == 'mtn1':
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
        model = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-one-mmt", )
        tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-many-to-one-mmt")
        # dataset should be like "xx-en"
        assert dataset.endswith('-en')
        src_lang = dataset[:2]
        from transformers.models.mbart.tokenization_mbart import FAIRSEQ_LANGUAGE_CODES
        match = [x for x in FAIRSEQ_LANGUAGE_CODES if x.startswith(src_lang)]
        assert len(match) == 1
        lang = match[0]
        tokenizer.src_lang = lang
        dataset = read_mt_data(name=dataset)
        dec_prefix = [tokenizer.eos_token_id,
                      tokenizer.lang_code_to_id["en_XX"]]
        logging.info(f"{tokenizer.decode(dec_prefix)}")
    model = model.to(device)
    return tokenizer, model, dataset,dec_prefix

import sys

import logging
from datasets import load_dataset
import torch
import random
import os

import pickle
import time
device = 'cuda:0'
tokenizer, model, dataset, dec_prefix= setup_model(device_name=device)

In [None]:

@torch.no_grad()
def run_inference_step(model, input_ids, attention_mask=None, decoder_input_ids=None, targets=None, device='cuda:0', output_dec_hid=False, T=1):
    decoder_input_ids = decoder_input_ids.to(device)
    input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    if decoder_input_ids.size()[0] != input_ids.size()[0]:
        target_batch_size = decoder_input_ids.size()[0]
        batch_input_ids = input_ids.expand(target_batch_size, input_ids.size()[1])
    else:
        batch_input_ids = input_ids
    assert decoder_input_ids.size()[0] == batch_input_ids.size()[0]

    model_inputs = {"input_ids": batch_input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                    }

    outputs = model(**model_inputs,
                    output_hidden_states=output_dec_hid,
                    use_cache=False, return_dict=True)

    # batch, dec seq, vocab size
    next_token_logits = outputs.logits[:, -1, :]
    if targets is not None:
        targets = targets.to(device)
        loss = torch.nn.functional.cross_entropy(
            input=next_token_logits, target=targets, reduction='none')
    else:
        loss = 0

    prob = torch.nn.functional.softmax(next_token_logits/T, dim=-1)

    return prob, next_token_logits, loss

In [None]:
doc = "BART model pre-trained on English language, and fine-tuned on CNN Daily Mail. It was introduced in the paper BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension by Lewis et al. and first released in [this repository (https://github.com/pytorch/fairseq/tree/master/examples/bart)."
ref = "BART is a transformer encoder-encoder (seq2seq) model with a bidirectional (BERT-like) encoder and an autoregressive (GPT-like) decoder."
doc_input_ids = torch.tensor(tokenizer(doc)['input_ids'],dtype=torch.long,device=device).unsqueeze(0)
ref_ids_list = tokenizer(ref)['input_ids'][1:-1]
ref_len = len(ref_ids_list)
ref_input_ids = torch.tensor(ref_ids_list ,dtype=torch.long,device=device)

print(doc_input_ids,ref_ids_list, ref_input_ids)

dec_prefixes_id = [tokenizer.eos_token_id]


for t in range(ref_len):
    target  = ref_ids_list[t]
    # print(dec_prefixes_id)
    dec_input_tensors = torch.tensor(dec_prefixes_id, dtype=torch.long, device=device).unsqueeze(0)
    prob, next_token_logits, loss = run_inference_step(model=model, input_ids=doc_input_ids, decoder_input_ids=dec_input_tensors, device=device,)
    oracle_prob = prob[0][target].tolist()
    dec_prefixes_id.append(target)
    print(oracle_prob)
    


In [1]:

from pyvis.network import Network

g = Network()
g.add_nodes([1, 2, 3],
            value=[1, 1, 1],
            title=["I am node 1", "node 2 here", "and im node 3"],
            x=[0, 100, 100],
            y=[0, 100, 200], label=["NODE 1", "NODE 2", "NODE 3"],
            color=["#00ff1e", "#162347", "#dd4b39"])

g.show('just_nodes.html')

g.add_edge(1, 2)
g.add_edge(1, 3)

g.show('with_edges.html')

for n in g.nodes:
    n.update({'physics': False})

g.show('example.html')

In [22]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
sent_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Story Cloze Test is a new commonsense reasoning framework for evaluating story understanding, story generation, and script learning. I kinda hate it.", return_tensors="pt")
with torch.no_grad():
    logits = sent_model(**inputs).logits

predicted_class_id = logits.argmax().item()
sent_model.config.id2label[predicted_class_id]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classi

'LABEL_0'

In [3]:
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(
            "/export/home/experimental/neurologic_decoding/gpt2-large/checkpoint-1800",
        )
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [20]:

inputs = tokenizer("A god leaps to his feet", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
logits = outputs.logits

last = logits[0,-1,:]
x = torch.argmax(last)
print (x)
print(tokenizer.decode(x.tolist()))


tensor(290)
 and
