In [1]:
import torch
import wandb
from datetime import datetime
import shutil
import importlib
import os
from importlib.machinery import SourceFileLoader
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import random
from torch.nn.utils.rnn import pad_sequence
from torch import nn

In [2]:
from transformers import (
    MODEL_WITH_LM_HEAD_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

In [3]:
SPECIAL_TOKENS = [
	'<speaker_1>', 
	'</speaker_1>', 
	
	'<speaker_2>',
	'</speaker_2>',
	
	'<persona>',
	'</persona>'
]

SPECIAL_TOKENS = {item:item for item in SPECIAL_TOKENS}

In [9]:
class BaseExperiment:
    def __init__(self, 
        model=None, 
        dataloader_train=None,
        dataloader_valid=None,
        dataloader_test=None,
        loss_func_class=None,
        estimate_func_class=None,
        experiment_config=None,
        optimizer_class=None,
        sheduler_class=None,
        project_name=None,
        notebook_name=None,
        name_run="",
        model_description=""
        ): 
        assert notebook_name != None, f"notebook_name should be valid filename, but get {notebook_name}"

        # datasets
        self.dataloader_train = dataloader_train
        self.dataloader_valid = dataloader_valid
        self.dataloader_test = dataloader_test
        
        # wandb
        self.notebook_name = notebook_name
        self.project_name = project_name 
        self.experiment_config = experiment_config
        self.wandb_run = None
        self.name_run = name_run
        self.model_description = model_description
        self.model_name = "pytorch_model"
        self.model_artifact = None

        self.optimizer_class = optimizer_class
        self.sheduler_class = sheduler_class
        self.loss_func_class = loss_func_class
        self.estimate_func_class = estimate_func_class

        self.model = model
        self.optimizer = None
        self.sheduler = None
        self.loss_func = None
        self.estimate_func = None
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')
        print(f"Using device {self.device}")

        # prepare for experiment
        self.setup()
        self.unit_tests()

    def setup(self):
        self.model.to(self.device)
        self.optimizer = self.optimizer_class(self.model.parameters(), **self.experiment_config['optimizer'])
        
        if self.sheduler_class != None:
            self.sheduler = self.sheduler_class(self.optimizer, **self.experiment_config['sheduler'])

        self.loss_func = self.loss_func_class()
        self.estimate_func = self.estimate_func_class()

        # set model name
        date_time = self.get_date()
        self.model_name = f"{self.name_run}---{date_time}.pt"
        self.experiment_config['model_name'] = self.model_name

        # setup wandb
        # save model structure and weights to wandb
        self.model_artifact = wandb.Artifact(
            self.name_run, type="model",
            description=self.model_description,
            metadata=self.experiment_config)


    def get_date(self):
        now = datetime.now()
        date_time = now.strftime("%m_%d_%Y__%H:%M:%S")
        return date_time

    def unit_tests(self):
        # test training
        X, y = next(iter(self.dataloader_train))
        X, y = X.to(self.device), y.to(self.device)

        pred = self.model(X)
        pred = pred[..., :-1, :].contiguous().view(-1, pred.size(-1))
        y = y[..., 1:].contiguous().view(-1)

        loss = self.loss_func(pred, y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # test valid
        X, y = next(iter(self.dataloader_valid))
        X, y = X.to(self.device), y.to(self.device)
        pred = self.model(X)
        pred = pred[..., :-1, :].contiguous().view(-1, pred.size(-1))
        y = y[..., 1:].contiguous().view(-1)
        test_loss = self.estimate_func(pred, y).item()
        # correct = (pred.argmax(1) == y).type(torch.float).sum().item()

        # initial validation
        self.model.eval()
        test_loss, correct = 0, 0
        num_batches = len(self.dataloader_valid)
        size = len(self.dataloader_valid.dataset)

        with torch.no_grad():
            for X, y in self.dataloader_valid:
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                pred = pred[..., :-1, :].contiguous().view(-1, pred.size(-1))
                y = y[..., 1:].contiguous().view(-1)
                test_loss += self.estimate_func(pred, y).item()
                # correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        # correct /= size
        print("Initial val = ", test_loss)

        print("tests ok")


    def train(self):
        # https://colab.research.google.com/github/wandb/examples/blob/master/colabs/wandb-artifacts/Pipeline_Versioning_with_W%26B_Artifacts.ipynb#scrollTo=qrAWbBV1rd4I
        # если попытаться создать переменную чтобы не городить тут код возникает ошибка с wandb!
        with wandb.init(project=self.project_name, entity="dimweb",
                        settings=wandb.Settings(
                            start_method="thread", 
                            # symlink=False
                            ),
                        reinit=True,
                        name=self.name_run,
                        config=self.experiment_config,
                        # sync_tensorboard=True
                        ) as run:

            self.run = run
            
            # save model class
            self.save_model_class()

            # start train
            epochs = self.experiment_config['epochs']
            for i in range(epochs):
                print(f"Epoch: {i}")
                self.train_steps()
                self.valid_steps()
            
            # sync model
            self.wandb_save_model()
            
            print(f"train end")
    
    def save_model_class(self):
        # save class
        model_class_name = self.experiment_config['model_class_name']
        class_script_path_dest = f"{os.path.join(wandb.run.dir, model_class_name)}.py"
        class_script_path_src = f"./models/{model_class_name}.py"
        shutil.copy2(class_script_path_src, class_script_path_dest)
        self.model_artifact.add_file(class_script_path_dest)
        wandb.save(class_script_path_dest)

    def wandb_save_model(self):
        # wandb использует symlinks для того чтобы сохранять файлы
        # но из-за проблем с правами доступа возникает ошибка и модель нельзя сохранить
        # поэтому пришлось сохранять модель в дирректорию с самим запуском
        # https://docs.wandb.ai/guides/track/advanced/save-restore#example-of-saving-a-file-to-the-wandb-run-directory
        model_save_path = os.path.join(wandb.run.dir, self.model_name)
        torch.save(self.model.state_dict(), model_save_path)
        self.model_artifact.add_file(model_save_path)
        wandb.save(model_save_path)

        # save notebook
        notebook_path = os.path.join(wandb.run.dir, self.notebook_name)
        shutil.copy2(self.notebook_name, notebook_path)
        self.model_artifact.add_file(notebook_path)
        wandb.save(notebook_path)

        wandb.log_artifact(self.model_artifact)
    
    def train_steps(self):
        raise NotImplementedError("You need specify training steps")

    def valid_steps(self):
        raise NotImplementedError("You need specify valid steps")

    def load_model(self, artifact_name, additional_model_args={}):
        assert artifact_name != ""
        with wandb.init(project=self.project_name, job_type="inference"):
            model_artifact = wandb.use_artifact(artifact_name)
            model_dir = model_artifact.download()
            model_config = model_artifact.metadata
            model_path = os.path.join(model_dir, model_config['model_name'])
            # print(model_config)

            model_class_name = model_config['model_class_name']
            model_script_path = f"./artifacts/{artifact_name}/{model_class_name}.py"
            # get module by path https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?rq=1 
            model_class = getattr(SourceFileLoader(model_class_name, model_script_path).load_module(), model_class_name)
            
            model_args = model_config['model_args']
            model = model_class(**model_args, **additional_model_args)

            model.load_state_dict(torch.load(model_path))
            self.model = model
            self.model.to(self.device)
    
    @staticmethod
    def static_load_model(artifact_name="", project_name="", additional_model_args={}):
        assert artifact_name != ""
        assert project_name != ""
        with wandb.init(project=project_name, job_type="inference"):
            model_artifact = wandb.use_artifact(artifact_name)
            model_dir = model_artifact.download()
            model_config = model_artifact.metadata
            model_path = os.path.join(model_dir, model_config['model_name'])

            model_class_name = model_config['model_class_name']
            model_script_path = f"./artifacts/{artifact_name}/{model_class_name}.py" 
            model_class = getattr(SourceFileLoader(model_class_name, model_script_path).load_module(), model_class_name)
            
            model_args = model_config['model_args']
            model = model_class(**model_args, **additional_model_args)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)
            model.load_state_dict(torch.load(model_path))
            
            return model

    def test(self, artifact_name="", model=None):
        raise NotImplementedError("You need specify test steps")


class Experiment(BaseExperiment):
    def __init__(self, **kwargs): 
        super(Experiment, self).__init__(**kwargs)
    
    def train_steps(self):
        self.model.train()
        interval = self.experiment_config['check_interval']

        for batch, (X, y) in enumerate(self.dataloader_train):
            # Send data to training device
            X, y = X.to(self.device), y.to(self.device)
            
            # Compute prediction error
            pred = self.model(X)
            loss = self.loss_func(pred, y)
            
            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            if self.sheduler != None:
                self.sheduler.step()
            
            # Progress output
            if batch % interval == 0:
                wandb.log({"train_loss": loss.item()})

    def valid_steps(self):
        self.model.eval()
        test_loss, correct = 0, 0
        num_batches = len(self.dataloader_valid)
        size = len(self.dataloader_valid.dataset)

        with torch.no_grad():
            for X, y in self.dataloader_valid:
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                test_loss += self.estimate_func(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size
        
        wandb.log({"val_loss": test_loss})
        wandb.log({"val_acc": correct})

    def test(self, artifact_name="", model=None):
        if model is None:
            self.load_model(artifact_name)
        else:
            self.model = model
            self.model.to(self.device)

        print("model loaded to disk")
        predictions = []
        
        self.model.eval()

        with torch.no_grad():
            for X, _ in self.test_dataloader:
                X = X.to(self.device)
                pred = self.model(X).argmax(1).cpu().numpy()
                predictions.extend(list(pred))
        
        date_time = self.get_date()
        filename = f"./predictions/{self.name_run}---{date_time}.csv"
        with open(filename, 'w') as solution:
            print('Id,Category', file=solution)
            for i, label in enumerate(predictions):
                print(f'{i},{label}', file=solution)
        print("test end")

## Создаем датасет

In [5]:
persona_chat_original = pd.read_csv("./persona_chat.csv")
persona_chat_original = persona_chat_original[:1000]
persona_chat_original.head(3)

Unnamed: 0.1,Unnamed: 0,Persona,chat
0,0,i like to remodel homes. i like to go hunting...,"hi , how are you doing ? i am getting ready to..."
1,1,my mom is my best friend. i have four sisters...,"hi , how are you doing today ?\ni am spending ..."
2,2,i had a gig at local theater last night. i wo...,"we all live in a yellow submarine , a yellow s..."


In [6]:
class PersonaChatGenerator:
	def __init__(self, 
		initial_dataset=None,
	):
		self.initial_dataset = initial_dataset
		self.processed_dataset = []
		self.process_dataset()

	def process_dataset(self):
		processed_dataset = {
			"persona": [],
			"history": [],
			# "target": []
		}

		speaker_1_start = SPECIAL_TOKENS['<speaker_1>']
		speaker_1_end = SPECIAL_TOKENS['</speaker_1>']
		
		speaker_2_start = SPECIAL_TOKENS['<speaker_2>']
		speaker_2_end = SPECIAL_TOKENS['</speaker_2>']

		for i in range(len(self.initial_dataset)):
			persona = self.initial_dataset['Persona'].iloc[i]
			persona = f"{SPECIAL_TOKENS['<persona>']} {persona} {SPECIAL_TOKENS['</persona>']}"
			chat = self.initial_dataset['chat'].iloc[i].split("\n")
			chat = chat[:-1]
			history = ""
			for j in range(len(chat)):
				reply = chat[j]
				if (j+1) % 2 == 0:
					reply = f"{speaker_2_start} {reply} {speaker_2_end}"
					history += reply

					processed_dataset['persona'].append(persona)
					processed_dataset['history'].append(history)
					# processed_dataset['target'].append(reply)

				else:
					reply = f"{speaker_1_start} {reply} {speaker_1_end}"
					history += reply 

		dataset = pd.DataFrame(data=processed_dataset)
		return dataset

train_dataset_csv, valid_dataset_csv = train_test_split(persona_chat_original, test_size=0.1)
train_dataset_csv, valid_dataset_csv = train_dataset_csv.reset_index(), valid_dataset_csv.reset_index()

train_dataset_generator = PersonaChatGenerator(
	initial_dataset=train_dataset_csv,
)

valid_dataset_generator = PersonaChatGenerator(
	initial_dataset=valid_dataset_csv,
)

In [7]:
class PersonaChatDataset(Dataset):
	def __init__(self, 
		initial_dataset=None,
		tokenizer=None
	):
		self.initial_dataset = initial_dataset
		self.tokenizer = tokenizer
	
	def __len__(self):
		return len(self.initial_dataset)
	
	def __getitem__(self, idx):
		row = self.initial_dataset.iloc[idx]
		persona = [item.strip() for item in row['persona'].split(".") if len(item) > 0 ]
		random.shuffle(persona)
		persona = [torch.tensor(self.tokenizer.encode(item)).flatten() for item in persona]
		persona = torch.cat([*persona])

		history = row['history']
		history = self.tokenizer.encode(history)
		history = torch.tensor(history).flatten()
		
		# target = row['target']
		# target = torch.tensor(self.tokenizer.encode(target)).flatten()
		# target = torch.cat([target, torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)])

		# feature = torch.cat([persona, history, torch.tensor([tokenizer.eos_token_id])])
		feature = torch.cat([persona, history])

		# target = torch.tensor(target, dtype=torch.long)
		# target = target.flatten()
		return {
			"feature": feature,
			"target": feature 
		}


tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
tokenizer.add_tokens(list(SPECIAL_TOKENS.values()), special_tokens=True)

train_dataset = PersonaChatDataset(
	initial_dataset=train_dataset_generator.process_dataset(),
	tokenizer=tokenizer
)

valid_dataset = PersonaChatDataset(
	initial_dataset=valid_dataset_generator.process_dataset(),
	tokenizer=tokenizer
)

def collate(examples):
	# print("EXAMPLES", examples)
	features = [item['feature'] for item in examples]
	features = pad_sequence(features, batch_first=True)
	
	targets = [item['target'] for item in examples]
	targets = pad_sequence(targets, batch_first=True)

	# return {
	# 	"feature": torch.tensor(features, dtype=torch.long),
	# 	"target": torch.tensor(targets, dtype=torch.long)
	# }
	return torch.tensor(features, dtype=torch.long), torch.tensor(features, dtype=torch.long)

train_dataloader = DataLoader(
    train_dataset, 
	batch_size=4, 
	collate_fn=collate, 
	drop_last = True,
	shuffle=True
)

valid_dataloader = DataLoader(
    valid_dataset, 
	batch_size=8, 
	collate_fn=collate, 
	drop_last=False,
	shuffle=False
)

## Тестируем dumb модель

In [10]:
# google colab не обновляет файлы, поэтому приходится делать это вручную, при помощи такого страшного импорта
def import_class(class_name):
    return getattr(importlib.reload(getattr(__import__(f"models.{class_name}"), class_name)), class_name)

TestModel = import_class("TestModel") 


exp_config = {
    "batch_size": 64,
    "check_interval": 100,
    "epochs": 1,
    "optimizer": {
        "lr": 1e-3
    },
    "model_name": "pytorch_model",
    "model_class_name": str(TestModel.__name__),
    "model_args": {
        "n_classes": len(tokenizer)
    }
}

exp_config["sheduler"] = {
    "max_lr": 0.01, 
    "steps_per_epoch": len(train_dataloader), 
    "epochs": 10
}

model = TestModel(**exp_config['model_args'])
model.test()

# не хочу создавать глобальные переменные 
exp_params = {
    "model": model, 
    "dataloader_train": train_dataloader,
    "dataloader_valid": valid_dataloader,
    "dataloader_test": valid_dataloader,
    "loss_func_class": nn.CrossEntropyLoss,
    "estimate_func_class": nn.CrossEntropyLoss,
    "experiment_config": exp_config,
    "optimizer_class": torch.optim.Adam,
    "sheduler_class": None,
    "notebook_name": "gpt_persona_v1.ipynb",
    "project_name": "gpt_persona_bot",
    "name_run": "test_run",
    "model_description": "Test my new model",
}

experiment_test = Experiment(**exp_params)

testt1
Using device cpu


  return torch.tensor(features, dtype=torch.long), torch.tensor(features, dtype=torch.long)


Initial val =  10.83613044984879
tests ok


In [40]:
outputs = torch.rand(
	(4, 123, 32),
	dtype=torch.float32,
)
# lin 
mat = torch.rand((32, 5623))
a = torch.matmul(outputs, mat)

In [64]:
from transformers import GPT2Tokenizer, GPT2Model, GPTLM
model = GPT2Model.from_pretrained("microsoft/DialoGPT-small")
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
tokenizer.add_tokens(list(SPECIAL_TOKENS.values()), special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

Some weights of the model checkpoint at microsoft/DialoGPT-small were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Embedding(50263, 768)

In [98]:
lin = nn.Linear(32, len(tokenizer))
x = torch.randint(0, 1, (4, 208), dtype=torch.long)

batch_size, sequence_length = x.shape[:2]
hidden_states = torch.rand(
	(batch_size, sequence_length, 32),
	dtype=torch.float32,
)
logits = lin(hidden_states)
print(logits.shape, x.shape)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = x[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
	shift_logits.view(-1, shift_logits.size(-1)), 
	shift_labels.view(-1))

loss

torch.Size([4, 208, 50263]) torch.Size([4, 208])


tensor(11.1798, grad_fn=<NllLossBackward0>)

In [80]:
from torch import nn
criteria = nn.CrossEntropyLoss()
input = torch.tensor([
	[[3.4, 1], [1.5, 1],[0.4, 1], [0.10, 1]],
	[[3.4, 1], [1.5, 1],[0.4, 1], [0.10, 1]],
	[[3.4, 1], [1.5, 1],[0.4, 1], [0.10, 1]],
],dtype=torch.float)
target = torch.tensor([
	[0, 0],
	[0, 0],
	[0, 0],
], dtype=torch.long)
print(input.shape, target.shape)
criteria(input, target)

torch.Size([3, 4, 2]) torch.Size([3, 2])


tensor(0.7992)