In [None]:
import os

# Expose only GPU 0 to this process; set here, before torch is imported in the cells below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


In [None]:
from transformers import GPT2LMHeadModel, AutoTokenizer

# Tokenizer of the pretrained "gpt2" checkpoint (the same checkpoint name MyNet loads).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="gpt2",
)

# Disabled earlier experiment with a BLOOM backbone:
# model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m-intermediate", revision='global_step10000')


In [None]:
import os

# Optional debugging overrides (disabled):
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Reuse the end-of-sequence token as the padding token so the tokenizer has a pad token
# (the language-modeling collator created later is built from this tokenizer).
tokenizer.pad_token = tokenizer.eos_token

from torch import nn
import torch


def compute_loss_labelsmoothed(logits, labels, ignore_index=-100, epsilon=0.1):
    """Label-smoothed next-token cross-entropy for a causal language model.

    The shift is performed here: the logits at position ``t`` are scored
    against the label at position ``t + 1``, so pass *unshifted* tensors.

    Args:
        logits: Float tensor whose last two dimensions are
            ``(seq_len, vocab_size)``.
        labels: Integer tensor of token ids with the same leading dimensions
            as ``logits`` but without the vocabulary axis. Positions equal to
            ``ignore_index`` do not contribute to the loss.
        ignore_index: Label value that marks positions to skip.
        epsilon: Smoothing weight; ``0`` yields the plain mean negative
            log-likelihood.

    Returns:
        Scalar tensor ``(1 - epsilon) * nll + epsilon * uniform`` where
        ``nll`` is the mean negative log-probability of the target token and
        ``uniform`` is the mean negative log-probability over the whole
        vocabulary, both averaged over the non-ignored positions. When there
        is no non-ignored position at all the result is ``0`` (not ``NaN``).
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    shifted_logits = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    neg_log_probs = -nn.functional.log_softmax(shifted_logits, dim=-1)
    if targets.dim() == neg_log_probs.dim() - 1:
        targets = targets.unsqueeze(-1)

    padding_mask = targets.eq(ignore_index)

    # gather() needs valid indices; ignored positions are clamped to token 0
    # here and their contribution is zeroed out right below.
    targets = torch.clamp(targets, min=0)
    nll_loss = neg_log_probs.gather(dim=-1, index=targets)

    # Summed in float32 so the reduction over the vocabulary stays accurate
    # even when the logits are half precision.
    smoothed_loss = neg_log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

    nll_loss = nll_loss.masked_fill(padding_mask, 0.0)
    smoothed_loss = smoothed_loss.masked_fill(padding_mask, 0.0)

    # Average over the non-ignored positions. The count is clamped to 1 so a
    # batch in which every position is ignored gives 0 instead of 0 / 0 = NaN.
    num_active_elements = (padding_mask.numel() - padding_mask.long().sum()).clamp(min=1)
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * neg_log_probs.shape[-1])
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss

    
class MyNet(nn.Module):
    """GPT-2 language model with extra linear "early exit" heads on intermediate layers.

    The outputs of every ``exit_every``-th intermediate transformer block are
    projected to vocabulary logits by a dedicated ``nn.Linear`` head. The
    embedding output and the output of the final block never get a head; the
    final block is scored by GPT-2's own LM head (``last_head`` in the output).

    GPT-2 is used rather than T5 because building a plain language-modeling
    dataset for T5 would require a lot of fiddling with prefixes.
    """

    def __init__(self, exit_every=4):
        """Load the pretrained ``gpt2`` checkpoint and create the early-exit heads.

        Args:
            exit_every: Spacing, in transformer blocks, between two early-exit
                heads. The default of 4 gives ``(n_layer - 1) // 4`` heads.

        Raises:
            ValueError: If ``exit_every`` is smaller than 1.
        """
        super().__init__()
        if exit_every < 1:
            raise ValueError(f"exit_every must be >= 1, got {exit_every}")

        self.transformer = GPT2LMHeadModel.from_pretrained('gpt2')
        # A BLOOM backbone ("bigscience/bloom-560m-intermediate") and resizing the
        # token embeddings by 30 entries were tried here earlier and are disabled.

        hid_size = self.transformer.config.hidden_size
        self.voc_size = self.transformer.config.vocab_size
        self.exit_every = exit_every

        # forward() works on output.hidden_states[1:-1]; the head count below is
        # sized for n_layer - 1 intermediate entries (assumes hidden_states holds
        # the embedding output plus one entry per block -- see transformers docs).
        num_exits = (self.transformer.config.n_layer - 1) // exit_every
        self.early_exits = nn.ModuleList([
            nn.Linear(hid_size, self.voc_size) for _ in range(num_exits)
        ])

        # Not used by forward(); kept so code that accesses ``ce`` keeps working.
        self.ce = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run GPT-2 and all early-exit heads.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone.
            labels: Optional target token ids. When given, a label-smoothed
                next-token loss is computed for every early-exit head.

        Returns:
            A dict with

            * ``'loss'`` (only when ``labels`` is given): sum over the
              early-exit heads of ``compute_loss_labelsmoothed``;
            * ``'head_outputs'``: list with one tensor per early-exit head,
              each ``[bs, seq_len, vocab_size]``, where entry
              ``[i][b][pos][tok]`` is the probability the i-th head assigns to
              token ``tok`` at position ``pos``;
            * ``'last_head'``: the same kind of probabilities from GPT-2's own
              LM head.
        """
        # ``labels`` is deliberately not forwarded: the backbone's own loss was
        # never used, so computing it was wasted work.
        output = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        # Drop the first and the last entry; each remaining entry is
        # [bs, seq_len, hid_dim].
        hidden_states = output.hidden_states[1:-1]

        # Head k reads the output of block (k + 1) * exit_every. The same stride
        # is used to size ``early_exits`` in __init__, so every head is applied
        # exactly once and the index is always in range.
        heads_logits = [
            head(hidden_states[(head_idx + 1) * self.exit_every - 1])
            for head_idx, head in enumerate(self.early_exits)
        ]

        heads_outputs = [
            torch.softmax(head_logits, dim=-1) for head_logits in heads_logits
        ]
        last_head = torch.softmax(output.logits, dim=-1)

        if labels is None:
            return {'head_outputs': heads_outputs, 'last_head': last_head}

        # One scalar loss per head, computed on the raw logits (not on the
        # probabilities), then summed into a single training objective.
        losses = [
            compute_loss_labelsmoothed(head_logits, labels)
            for head_logits in heads_logits
        ]
        total_loss = torch.stack(losses).sum()

        return {
            'loss': total_loss,
            'head_outputs': heads_outputs,
            'last_head': last_head,
        }
        

In [None]:
import torch

In [None]:
# I have a dictionary with the probabilities of predicting the correct token.
# Next, I pick word tokens and look at when they were predicted well and when they were not.

In [None]:
from transformers import DataCollatorForLanguageModeling

# Language-modeling collator built from the tokenizer, with masked-LM mode switched off.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Build the model, move it to the GPU and put it in evaluation mode.
# Both .to() and .eval() return the module itself, so the calls can be chained.
net = MyNet().to('cuda').eval()

In [None]:
# Cache directory handed to `load_dataset` in the (disabled) preprocessing recipe.
dataset_cache = 'files/.cache'
# Directory the tokenized dataset is saved to and loaded from.
dataset_path = 'files/dataset'

In [None]:
from datasets import load_from_disk

# ---------------------------------------------------------------------------
# One-off preprocessing that produced the data under `dataset_path`.
# Kept disabled for reference: sample 500k English Wikipedia articles,
# tokenize them (truncated to 512 tokens) and save the result to disk.
# ---------------------------------------------------------------------------
# from datasets import load_dataset
# import numpy as np

# dataset = load_dataset("wikipedia", "20220301.en", cache_dir=dataset_cache)

# rand_idx = np.random.choice(np.arange(len(dataset['train'])), size=500_000, replace=False)

# # import json
# # rand_idx = json.load(open('indices.json', 'r'))

# dataset = dataset['train'].select(rand_idx, )

# # import json

# # json.dump(rand_idx.tolist(), open('indices.json', 'w'),)

# def tokenize_data(example):
#     return tokenizer(example['text'], max_length=512, truncation=True)

# dataset = dataset.map(
#     tokenize_data, remove_columns=['text', 'id', 'url', 'title'], batched=True, num_proc=10
# )

# dataset.save_to_disk(dataset_path)


# Load the dataset previously written to `dataset_path`.
dataset = load_from_disk(dataset_path)

In [None]:
# Freeze every parameter whose name contains 'transformer' (the GPT-2 backbone
# stored in MyNet.transformer); the early-exit heads stay trainable.
for param_name, param in net.named_parameters():
    if 'transformer' not in param_name:
        continue
    param.requires_grad_(False)

In [None]:
from transformers import Trainer, TrainingArguments

# Optimisation, logging and checkpointing settings; anything not listed keeps
# the library default.
training_args = TrainingArguments(
    output_dir='logs/gpt2',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,  # 2 x 2 = 4 sequences per optimizer step per device
    fp16=True,
    logging_steps=100,
    save_steps=1000,
    save_total_limit=2,
)

trainer = Trainer(
    model=net,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    tokenizer=tokenizer,
)

In [None]:
# Run training with the Trainer configured above. Left as the bare last
# expression of the cell so the notebook displays the value train() returns.
trainer.train()