# Training models for all Experiments
##### numbers of epoches are based on analysis of k-fold results
##### all hyperparams are fixed

In [15]:
from transformers import BartTokenizer
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
#from bart import SimpleBart
from experiments.experiment_type import ExperimentType
from dataset.scan_dataset import ScanDatasetHF

device = 'cuda'#'cpu'

In [16]:
# MODEL DEFINITION
import torch
from torch import nn
from transformers import BartModel, BartForConditionalGeneration
from adapters import AutoAdapterModel

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id = 1, decoder_start_token_id = 2):
    """
    Shift input ids one token to the right.
    """

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


class SimpleBart(nn.Module):
    def __init__(self, out_vocab_size, input_length = 120):
        super().__init__()
        #self.bart = AutoAdapterModel.from_pretrained('facebook/bart-base')
        #self.bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.bart = BartModel.from_pretrained('facebook/bart-base')
        #self.up = nn.Linear(768, vocab_size)
        #self.post = nn.Linear(vocab_size, vocab_size)

        self.up = nn.Linear(768, out_vocab_size, device=device)
        

    @property
    def adaptable(self):
        return self.bart

    #def forward(self, in_ids, in_mask, tgt_ids, tgt_mask):
    def forward(self, kwargs):
        """
        Perform a forward pass through the model.
        NOT autoregressive
        Args:
            in_ids (torch.Tensor): Input IDs for the encoder.
            in_mask (torch.Tensor): Attention mask for the encoder input.
            tgt_ids (torch.Tensor): Input IDs for the decoder.
            tgt_mask (torch.Tensor): Attention mask for the decoder input.
        Returns:
            torch.Tensor: The output of the model after passing through the encoder and decoder.
        """

        #x = self.bart( **kwargs).logits

        x = self.bart( **kwargs).last_hidden_state

        x = self.up(x)
        #x = self.post(x)
        return x

In [17]:
# LR, BATCH_SIZE
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
batch_size = 32
lr = 0.001#0.0001
w_decay = 0.00001
criterion = CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [34]:
#DATA
e_type = ExperimentType.E_1_1
data_paths = e_type.get_data_paths()

train_dataset = ScanDatasetHF(data_paths["train"], tokenizer)
test_dataset = ScanDatasetHF(data_paths["test"], tokenizer)
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True
                          )
test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          shuffle=True
                          )

output_vocab = torch.concatenate([train_dataset.output_vocab,test_dataset.output_vocab])
output_vocab = torch.sort(torch.unique(output_vocab))
_outvoc2voc = dict(zip(output_vocab.indices.numpy(),output_vocab.values.numpy()))
_voc2outvoc = dict(zip(output_vocab.values.numpy(),output_vocab.indices.numpy()))
output_vocab =output_vocab.values


def outvoc2voc(in_tensor):
    out_tensor = torch.zeros(in_tensor.shape, dtype=torch.int64, device=device)    
    for j in range(in_tensor.shape[0]):
        out_tensor[j] =  torch.tensor([_outvoc2voc[i] for i in in_tensor[j].cpu().numpy()], dtype=torch.int64, device=device)
    return out_tensor
def voc2outvoc(in_tensor):
    out_tensor = torch.zeros(in_tensor.shape, dtype=torch.int64, device=device)  
    #print(in_tensor)
    for j in range(in_tensor.shape[0]):
        out_tensor[j] = torch.tensor([_voc2outvoc[i] for i in in_tensor[j].cpu().numpy()], dtype=torch.int64, device=device)
    return out_tensor

In [27]:
#MODEL INITIALIZATION
import adapters

model = SimpleBart(out_vocab_size = output_vocab.size(0))
model = model.to(device)


for name, param in model.bart.named_parameters():
    if "bart"in name:
        param.requires_grad = False




adapters.init(model.adaptable)

from adapters import SeqBnInvConfig
config = SeqBnInvConfig()
model.adaptable.add_adapter("lang_adapter", config=config)
model.adaptable.set_active_adapters("lang_adapter")
model.adaptable.train_adapter("lang_adapter")

# from adapters import PrefixTuningConfig
# config = PrefixTuningConfig(flat=False, prefix_length=30)
# model.adaptable.add_adapter("prefix_tuning", config=config)
# model.adaptable.set_active_adapters("prefix_tuning")
# model.adaptable.train_adapter("prefix_tuning")



# from adapters import IA3Config
# config = IA3Config()
# model.adaptable.add_adapter("ia3_adapter", config=config)
# model.adaptable.set_active_adapters("ia3_adapter")
# model.adaptable.train_adapter("ia3_adapter")

#model.load_state_dict(torch.load("tbag/bart_1_1_5k.pth"))
model = model.to(device)



all_cnt = 0
trainable_cnt = 0
for name, param in model.named_parameters():
    all_cnt += param.numel()
    if param.requires_grad:
        trainable_cnt += param.numel()
        print(name, 'size:', param.numel())

#model.bart.lm_head.weight.requires_grad = True

print('prcnt of trainable:', trainable_cnt/(all_cnt-trainable_cnt))
print("All in millions:", all_cnt/1000000)
print("Trainable in millions:", trainable_cnt/1000000)


bart.encoder.layers.0.output_adapters.adapters.lang_adapter.adapter_down.0.weight size: 36864
bart.encoder.layers.0.output_adapters.adapters.lang_adapter.adapter_down.0.bias size: 48
bart.encoder.layers.0.output_adapters.adapters.lang_adapter.adapter_up.weight size: 36864
bart.encoder.layers.0.output_adapters.adapters.lang_adapter.adapter_up.bias size: 768
bart.encoder.layers.1.output_adapters.adapters.lang_adapter.adapter_down.0.weight size: 36864
bart.encoder.layers.1.output_adapters.adapters.lang_adapter.adapter_down.0.bias size: 48
bart.encoder.layers.1.output_adapters.adapters.lang_adapter.adapter_up.weight size: 36864
bart.encoder.layers.1.output_adapters.adapters.lang_adapter.adapter_up.bias size: 768
bart.encoder.layers.2.output_adapters.adapters.lang_adapter.adapter_down.0.weight size: 36864
bart.encoder.layers.2.output_adapters.adapters.lang_adapter.adapter_down.0.bias size: 48
bart.encoder.layers.2.output_adapters.adapters.lang_adapter.adapter_up.weight size: 36864
bart.enco

In [28]:
# CONFIG
from math import e


training_config = {
    "name": e_type.name,
    "tokenizer" : tokenizer,
    "model" : model,
    "evaluator" : None,
    "optimizer" : None,
    "grad_clip" : 5.0,
    "lr" : lr,
    "scheduler" : None,
    "criterion" : criterion,
    "train_dataset" : train_dataset,
    "test_dataset" : test_dataset,
    "train_loader" : train_loader,
    "test_loader" : test_loader,
    "batch_size" : batch_size,
    "max_steps" : 10 * 1000,#0,#//batch_size,
    "max_epochs" : None,
    "evaluation_interval" : 50,
    "model_save_interval" : 50,
    "detailed_logging" : True,
    "use_tensorboard" : False,
    "tensorboard_dir" : None,
    "model_save_dir" : None,

}

In [29]:
# TRAINING LOOP
from os import name
from turtle import mode
from tqdm.notebook import tqdm
from tensorboardX import SummaryWriter
from math import ceil
from torch.nn.utils import clip_grad_norm_

from dataset import vocab

def _train(config):
    name = config["name"]
    tokenizer = config["tokenizer"]
    model = config["model"]
    evaluator = config["evaluator"]
    optimizer = config["optimizer"]
    scheduler = config["scheduler"]
    grad_clip = config["grad_clip"]
    lr = config["lr"]
    criterion = config["criterion"]
    train_dataset = config["train_dataset"]
    test_dataset = config["test_dataset"]
    train_loader = config["train_loader"]
    test_loader = config["test_loader"]
    batch_size = config["batch_size"]
    max_steps = config["max_steps"]
    epochs = config["max_epochs"]
    evaluation_interval = config["evaluation_interval"]
    model_save_interval = config["model_save_interval"]
    detailed_logging = config["detailed_logging"]
    use_tensorboard = config["use_tensorboard"]
    tensorboard_dir = config["tensorboard_dir"]
    model_save_dir = config["model_save_dir"]

    optimizer = AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=w_decay,
        )
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)

    r_indx = name.find("_rep")
    org_name_base,org_rep_info = name[:r_indx], name[r_indx:]


    name = org_name_base + org_rep_info

    if use_tensorboard:
        writer = SummaryWriter(tensorboard_dir)

    # if model_weights_path is not None:
    #     _load_weights()

    grad_clip = grad_clip
    batch_size = batch_size
    
    max_epoch = epochs if epochs is not None else ceil(max_steps/len(train_dataset))
    max_steps = max_steps if max_steps is not None else (max_epoch*len(train_loader)*batch_size)+1
    max_batches = max_steps//batch_size
    batch_num = 0
    step_num = 0
    batch_step = 0

    eval_interval = evaluation_interval
    model_sv_interval = model_save_interval

    

    if detailed_logging:
        print("Training started for experiment: ", name)

    
    t ,_,_,_,_,_ = train_dataset[0]
    
    tgt_empty = torch.full((batch_size,t.shape[0]), tokenizer.pad_token_id, device=device)
    tgt_empty_msks = torch.full(tgt_empty.shape, 0, device=device)
    tgt_empty_msks[:,0] = 1

    # TRAINING LOOP
    epoch_progress = tqdm(range(max_epoch), desc="EPOCH")
    for epoch in epoch_progress:
        if max_steps is not None and step_num >= max_steps:
                print("Early Stopping: Max steps reached")
                break
        total_loss = 0
        model.train()
        #batch_bar = tqdm(train_loader, desc="BATCH")
        #for batch in batch_bar:
        batch_step = 0
        for batch in train_loader:
            if max_steps is not None and step_num >= max_steps:
                break

            inputs,inputs_msks, dec_in, dec_in_msks, targets, targets_msks = batch
            optimizer.zero_grad()

            inputs = inputs.to(device)
            inputs_msks = inputs_msks.to(device)
            dec_in = dec_in.to(device)
            dec_in_msks = dec_in_msks.to(device)
            dec_in_msks[:,0] = 1
            targets = targets.to(device)
            targets_msks = targets_msks.to(device)

            


            #out = model(inputs,inputs_msks,dec_in, dec_in_msks)

            if True: #batch_num%2: #step_num < max_steps//2:
                m = {
                    "input_ids": inputs,
                    "attention_mask" : inputs_msks,
                    "decoder_input_ids" : shift_tokens_right(voc2outvoc(dec_in)), #None, #targets,
                    "decoder_attention_mask" : shift_tokens_right(dec_in_msks), #WRONG!!!!
                    "head_mask" : None,
                    "decoder_head_mask" : None,
                    "cross_attn_head_mask" : None,
                    "encoder_outputs" : None,
                    "past_key_values" : None,
                    "inputs_embeds" : None,
                    "decoder_inputs_embeds" : None,
                    "use_cache" : None,
                    "output_attentions" : None,
                    "output_hidden_states" : None,
                    "return_dict" : None,
                }
            else:
                m = {
                    "input_ids": inputs,
                    "attention_mask" : inputs_msks,
                    "decoder_input_ids" : shift_tokens_right(tgt_empty), #None, #targets,
                    "decoder_attention_mask" : tgt_empty_msks, # None, #targets_msks,
                    "head_mask" : None,
                    "decoder_head_mask" : None,
                    "cross_attn_head_mask" : None,
                    "encoder_outputs" : None,
                    "past_key_values" : None,
                    "inputs_embeds" : None,
                    "decoder_inputs_embeds" : None,
                    "use_cache" : None,
                    "output_attentions" : None,
                    "output_hidden_states" : None,
                    "return_dict" : None,
                }
            out = model(m)

            targets_cmpr = voc2outvoc(targets)

            loss = criterion(out.permute(0, 2, 1), targets_cmpr)
            loss.backward()
            clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            total_loss += loss.item()
            batch_num += 1
            batch_step += 1
            step_num += batch_size

            if detailed_logging or batch_num == len(train_loader):
                print(f"Epoch {epoch+1}/{max_epoch} Batch {batch_num}/{max_batches} Trining Loss: {total_loss / (batch_num)} Step: {step_num}/{max_steps}" )

            

        epoch += 1
        if detailed_logging or epoch == max_epoch:
            print(f"Epoch {epoch}/{max_epoch} Batch {batch_num}/{max_batches} Trining Loss: {total_loss / (batch_num + 1)}")

        # EVALUATION 
        if epoch % eval_interval == 0 or epoch == max_epoch or max_steps is not None and step_num >= max_steps:
            pass
            # model.eval()
            # with torch.no_grad():
            #     total_loss = 0
            #     for batch in test_loader:
            #         inputs,inputs_msks, dec_in, dec_in_msks, targets, targets_msks = batch
            #         inputs = inputs.to(device)
            #         inputs_msks = inputs_msks.to(device)
            #         targets = targets.to(device)
            #         targets_msks = targets_msks.to(device)
            #         dec_in = dec_in.to(device)
            #         dec_in_msks = dec_in_msks.to(device)

            #         #out = model(inputs,inputs_msks,None, None)
            #         m = {
            #             "input_ids": inputs,
            #             "attention_mask" : inputs_msks,
            #             "decoder_input_ids" : None,
            #             "decoder_attention_mask" : None,
            #             "head_mask" : None,
            #             "decoder_head_mask" : None,
            #             "cross_attn_head_mask" : None,
            #             "encoder_outputs" : None,
            #             "past_key_values" : None,
            #             "inputs_embeds" : None,
            #             "decoder_inputs_embeds" : None,
            #             "use_cache" : None,
            #             "output_attentions" : None,
            #             "output_hidden_states" : None,
            #             "return_dict" : None,
            #         }
            #         out = model(m)
            #         loss = criterion(out.permute(0, 2, 1), targets)
            #         total_loss += loss.item()
            #     print(f"Epoch {epoch}/{max_epoch} Validation Loss: {total_loss / len(test_loader)}")


        #     result = evaluate_model_batchwise(model,test_dataset, test_loader, test_dataset.vocab, device)
        #     #result = EvaluationResult(result, name+fold_info+f"_epoch_{epoch}", e_type)
        #     result = EvaluationResult(result, name+f"_epoch_{epoch}", e_type)
        #     if detailed_logging or epoch == max_epoch:
        #         result.print()
        #     result_container.append_results(result)
        
        if epoch % model_sv_interval == 0 or epoch == max_epoch and not model_save_dir is None:
            buff = name
            name = org_name_base + f"_epoch_{epoch}" + org_rep_info
            torch.save(model.state_dict(),model_save_dir+'/'+name)
            name = buff


        if use_tensorboard:
            writer.add_scalar(tag = 'TrainLoss',
                                scalar_value = total_loss,
                                global_step = epoch)
        }
    if use_tensorboard:
        writer.close()
    return model

In [35]:
_train(training_config)

print("Finished Training")

Training started for experiment:  E_1_1


EPOCH:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1/1 Batch 1/312 Trining Loss: 3.4110898971557617 Step: 32/10000
Epoch 1/1 Batch 2/312 Trining Loss: 3.123060464859009 Step: 64/10000
Epoch 1/1 Batch 3/312 Trining Loss: 3.0406734148661294 Step: 96/10000
Epoch 1/1 Batch 4/312 Trining Loss: 2.942711591720581 Step: 128/10000
Epoch 1/1 Batch 5/312 Trining Loss: 2.8641037940979004 Step: 160/10000
Epoch 1/1 Batch 6/312 Trining Loss: 2.7966291507085166 Step: 192/10000
Epoch 1/1 Batch 7/312 Trining Loss: 2.7455385412488664 Step: 224/10000
Epoch 1/1 Batch 8/312 Trining Loss: 2.7055512964725494 Step: 256/10000
Epoch 1/1 Batch 9/312 Trining Loss: 2.6761597792307534 Step: 288/10000
Epoch 1/1 Batch 10/312 Trining Loss: 2.6482025623321532 Step: 320/10000
Epoch 1/1 Batch 11/312 Trining Loss: 2.6202081116763027 Step: 352/10000
Epoch 1/1 Batch 12/312 Trining Loss: 2.593984007835388 Step: 384/10000
Epoch 1/1 Batch 13/312 Trining Loss: 2.568935430966891 Step: 416/10000
Epoch 1/1 Batch 14/312 Trining Loss: 2.540286728313991 Step: 448/10000
Epoch 1/1

In [37]:
from transformers import PreTrainedTokenizer
def predict_batch(model, src_sequences, src_msks, tokenizer:PreTrainedTokenizer,
                   max_len=148, device="cpu"):

    

    model.eval()
    with torch.no_grad():
        strt_t = model.bart.config.decoder_start_token_id

        tgt = torch.full((src_sequences.size(0),1), strt_t, device=device)    

        # Generate tokens one by one
        for i in range(1,max_len):
            tgt_msks = torch.full(tgt.shape, 0, device=device)

            m = {
                "input_ids": src_sequences,
                "attention_mask" : src_msks,
                "decoder_input_ids" : tgt,
                "decoder_attention_mask" : tgt_msks,
                "head_mask" : None,
                "decoder_head_mask" : None,
                "cross_attn_head_mask" : None,
                "encoder_outputs" : None,
                "past_key_values" : None,
                "inputs_embeds" : None,
                "decoder_inputs_embeds" : None,
                "use_cache" : None,
                "output_attentions" : None,
                "output_hidden_states" : None,
                "return_dict" : None,
            }
            out = model(m)

            # Get next token prediction
            next_token = out.argmax(dim=-1)
            next_token = next_token[:, -1]
            next_token = next_token.unsqueeze(1)
            next_token = outvoc2voc(next_token)
            # store new prediction for active sequences
            tgt = torch.cat((tgt, next_token), dim=1)

    return tgt

# def predict_batch(model, src_sequences, src_msks, tokenizer:PreTrainedTokenizer,
#                    max_len=148, device="cpu"):
#     """
#     Generates predictions for a batch of source sequences using the given model.
#     Args:
#         model (torch.nn.Module): The model used for generating predictions.
#         src_sequence (torch.Tensor): The source sequences to be translated.
#         vocab (Vocabulary): The vocabulary object containing all tokens.
#         max_len (int, optional): The maximum length of the generated sequences. Defaults to 128.
#         device (str, optional): The device to run the model on ("cpu" or "cuda"). Defaults to "cpu".
#     Returns:
#         torch.Tensor: The generated target sequences.
#     """

#     eos = tokenizer.convert_tokens_to_ids(['<s\>'])[0]
#     bos = tokenizer.convert_tokens_to_ids(['<s>'])[0]

#     model.eval()
#     with torch.no_grad():
#         # Initialize target sequence with SOS token
#         #tgt = [vocab.sos_idx] + [vocab.pad_idx] * (max_len)
#         tgt = [bos] + ([tokenizer.pad_token_type_id] * (max_len))
#         tgt = torch.tensor([tgt], device=device)
#         tgt = tgt.repeat(src_sequences.size(0), 1)

#         tgt_msks = [1] + ([tokenizer.pad_token_type_id])*max_len
#         tgt_msks = torch.tensor([tgt_msks], device=device)
#         tgt_msks = tgt_msks.repeat(src_sequences.size(0), 1)

#         # holds indx of sequences containing EOS token
#         finished = torch.tensor([False]*src_sequences.size(0), device=device)
        
#         # Generate tokens one by one
#         for i in range(1,max_len):
#             tgt_msks[:,i] = 1
#             # indx of batch dim where EOS token has not been generated
#             active_indxs = torch.where(finished == False)[0]
#             # feed only unfinished sequences to decoder
#             # remove padding
#             m = {
#                 "input_ids": src_sequences[active_indxs],
#                 "attention_mask" : src_msks[active_indxs],
#                 "decoder_input_ids" : tgt[active_indxs],
#                 "decoder_attention_mask" : tgt_msks[active_indxs],
#                 "head_mask" : None,
#                 "decoder_head_mask" : None,
#                 "cross_attn_head_mask" : None,
#                 "encoder_outputs" : None,
#                 "past_key_values" : None,
#                 "inputs_embeds" : None,
#                 "decoder_inputs_embeds" : None,
#                 "use_cache" : None,
#                 "output_attentions" : None,
#                 "output_hidden_states" : None,
#                 "return_dict" : None,
#             }
#             out = model(m)

#             # Get next token prediction
#             next_token = out.argmax(dim=-1)
#             next_token = next_token[:, -1]
#             # store new prediction for active sequences
#             tgt[active_indxs,i] = next_token
            
#             # update finished sequences if any EOS token is generated
#             new_finished = torch.where(next_token == eos, True, False)
#             finished[active_indxs] = torch.logical_or(finished[active_indxs], new_finished)

#             # early stopping if all sequences produced EOS token
#             if finished.all():
#                 break
#     return tgt


In [39]:
inputs,inputs_msks, dec_in, dec_in_msks, targets, targets_msks = test_dataset[2]
inputs = inputs.repeat(2,1).to(device)
inputs_msks = inputs_msks.repeat(2,1).to(device)
print(inputs.shape)
pred = predict_batch(model, inputs, inputs_msks, tokenizer, device=device)
print(tokenizer.decode(targets, skip_special_tokens=True))
tokenizer.decode(pred[0].tolist(), skip_special_tokens=False)


torch.Size([2, 120])
I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_LOOK I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT


'</s>II_______________________ IT IT_ IT IT_ IT IT_ IT_ IT IT_ IT_ IT_ IT_ I_______________________</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>'

tensor([[2., 0., 0., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2.],
        [2., 0., 0., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
  

In [20]:
m = {
                "input_ids": inputs,
                "attention_mask" : inputs_msks,
                "decoder_input_ids" : None,
                "decoder_attention_mask" : None,
                "head_mask" : None,
                "decoder_head_mask" : None,
                "cross_attn_head_mask" : None,
                "encoder_outputs" : None,
                "past_key_values" : None,
                "inputs_embeds" : None,
                "decoder_inputs_embeds" : None,
                "use_cache" : None,
                "output_attentions" : None,
                "output_hidden_states" : None,
                "return_dict" : None,
            }
out = model(m)
#PRED
tokenizer.decode(out.argmax(dim=-1)[0].tolist(), skip_special_tokens=False)

'II I_ I_ I_ I<s>_</s></s></s></s>TTURNURNURNRRRRURN<s><s>URNURNURNURNURN<s><s>URNURNURNURNURNURN__URNUMPUMPOOKURN_______URNOOK I_________OOKT_________TT__<s><s></s>___<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>'

In [14]:
rand = torch.randint(0, len(train_dataset), (1,))

rand.to(torch.long)

tensor([10457])

In [15]:
inputs

tensor([[    0, 13724,  5483,   235, 10161,  2463,    71,   356,   198,   314,
             2,     2,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    0, 13724,  5483,   235, 10161,  2463, 

In [16]:
targets

tensor([    0,   100,  1215,   565, 28267,  1215,  3850, 11615,    38,  1215,
          574, 23775,    38,  1215,   565, 28267,  1215,  3850, 11615,    38,
         1215,   574, 23775,    38,  1215,   565, 28267,  1215,  3850, 11615,
           38,  1215,   574, 23775,    38,  1215,   565, 28267,  1215,  3850,
        11615,    38,  1215,   574, 23775,    38,  1215,   565, 28267,  1215,
          500,  8167,    38,  1215,   565, 28267,  1215,   500,  8167,    38,
         1215,   574, 23775,    38,  1215,   565, 28267,  1215,   500,  8167,
           38,  1215,   565, 28267,  1215,   500,  8167,    38,  1215,   574,
        23775,    38,  1215,   565, 28267,  1215,   500,  8167,    38,  1215,
          565, 28267,  1215,   500,  8167,    38,  1215,   574, 23775,     2,
            2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1])

In [17]:
special = tokenizer.all_special_tokens
for s in special:
    print(s, tokenizer.convert_tokens_to_ids(s))

<s> 0
</s> 2
<unk> 3
<pad> 1
<mask> 50264


In [18]:
#torch.save(model.state_dict(),"tbag/bart_1_1_5k.pth")

In [26]:
a = {1,3,2}
sorted(a)
a.add(1)
a

{1, 2, 3}

In [36]:
a = torch.randint(0, 100, (5,10))

a = torch.sort(torch.unique(a.flatten())).values
a

tensor([ 0,  1,  4,  5,  6,  8, 10, 13, 15, 16, 19, 21, 24, 27, 29, 30, 32, 33,
        34, 36, 37, 42, 47, 52, 53, 54, 55, 57, 61, 63, 66, 68, 76, 77, 78, 81,
        84, 89, 90, 93])

In [38]:
b = torch.full((5,100), 1)

#a = torch.tensor([1,2,3,4,5,6,7,8,9,10])
a.shape

# b.shape
#b[:,a].shape


torch.Size([40])

In [25]:
output_vocab = torch.concatenate([train_dataset.output_vocab,test_dataset.output_vocab])
output_vocab = torch.sort(torch.unique(output_vocab))
output_vocab






{0: 0,
 1: 1,
 2: 2,
 3: 8,
 4: 71,
 5: 198,
 6: 235,
 7: 314,
 8: 356,
 9: 422,
 10: 1004,
 11: 1656,
 12: 2330,
 13: 2463,
 14: 2962,
 15: 3704,
 16: 5483,
 17: 10097,
 18: 10161,
 19: 13724,
 20: 15922,
 21: 43750}