In [1]:
import pandas as pd
# from tokenizers.implementations import BertWordPieceTokenizer
from transformers import PreTrainedTokenizerFast
import torch
import torch.nn as nn
import numpy as np
from itertools import chain
from collections import Counter
import gc
from torch.utils.data import DataLoader, Dataset
import mlflow
import torch.optim as optim
import mlflow.pytorch
from torchmetrics import F1Score as F1
from tqdm import tqdm

train_df = pd.read_parquet("task3_recsys/df_for_transformer_TRAIN.parquet")
test_df = pd.read_parquet("task3_recsys/df_for_transformer_TEST.parquet")
# 256-token BPE vocabulary: gave the most reasonable token distribution compared with larger vocabularies.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="task3_recsys/models/tokenizer/vocab_256_bpe.json")
special_tokens = {"mask_token": "[MASK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]"}
n_added_tokens = tokenizer.add_special_tokens(special_tokens)

# Later cells hard-code special-token ids (PADTOKEN = 259, mask_input(mask_token=256));
# fail fast here if the tokenizer assigns different ids.
assert tokenizer.pad_token_id == 259, "padding not aligned"
assert tokenizer.mask_token_id == 256, "mask token not aligned"

n_added_tokens


4

In [2]:
train_df.head(n=5)  # quick look at the first rows

Unnamed: 0,app_id,mcc_cat_list,amnt_list_discrete,next_mcc
0,878469,bbbbbbgbbeudnvbgfgffxxddvdffnaejjeajbgdjedqqal...,"[8, 8, 8, 8, 8, 8, 2, 8, 8, 7, 8, 6, 5, 8, 8, ...",a
1,673064,bbbbbb,"[6, 6, 6, 5, 6, 6]",b
2,978962,bbbbbbbadaadbeeaaabbbbbaeadaaabaavbbk,"[8, 7, 6, 6, 6, 7, 4, 7, 5, 7, 4, 6, 6, 6, 6, ...",h
3,596073,bbbbbbbbggikieokfggaaaagaxfabafaaaabgggagbaaag...,"[7, 7, 7, 5, 6, 6, 7, 8, 7, 3, 6, 8, 7, 4, 6, ...",j
4,528044,cbaabbbdcbcbibdccbcbbbbbsidgfccdbccbdgabbbdida...,"[5, 8, 7, 6, 4, 1, 7, 2, 2, 4, 6, 6, 7, 8, 5, ...",c


In [3]:
def tokenize_mcc(texts, max_length=256):
    """Tokenize MCC-category strings into left-padded token ids plus an attention mask.

    Returns ``(input_ids, attention_mask)`` as ``int16`` / ``bool`` tensors of shape
    ``(len(texts), max_length)``; the attention mask is True on real (non-[PAD]) tokens.
    int16 is used to save memory (token ids here are small; [PAD] is 259).
    """
    texts = list(texts)
    if not texts:  # nothing to tokenize: return correctly shaped empty tensors
        empty_ids = torch.empty((0, max_length), dtype=torch.int16)
        return empty_ids, empty_ids.bool()
    encoded = tokenizer(texts, return_tensors="pt", max_length=max_length, padding="max_length",
                        truncation=True, padding_side="left", return_token_type_ids=False)
    return encoded.input_ids.to(torch.int16), encoded.attention_mask.to(torch.bool)


# Tokenize the train split in three chunks to bound peak memory. The bounds never overlap:
# for >= 500k rows they equal [:250000], [250000:-250000], [-250000:]; for smaller frames
# the middle/last chunks shrink (possibly to empty) instead of duplicating rows.
TRAIN_CHUNK = 250_000
n_train = len(train_df)
first_end = min(TRAIN_CHUNK, n_train)
last_start = max(first_end, n_train - TRAIN_CHUNK)

vectorized_mcc_train0, vectorized_mcc_train0_attnmask = tokenize_mcc(train_df.iloc[:first_end].mcc_cat_list.values)
gc.collect()

vectorized_mcc_train1, vectorized_mcc_train1_attnmask = tokenize_mcc(train_df.iloc[first_end:last_start].mcc_cat_list.values)
gc.collect()

vectorized_mcc_train2, vectorized_mcc_train2_attnmask = tokenize_mcc(train_df.iloc[last_start:].mcc_cat_list.values)
gc.collect()

vectorized_mcc_test, vectorized_mcc_test_attnmask = tokenize_mcc(test_df.mcc_cat_list.values)

In [4]:
# Stitch the tokenized train chunks back together (row order = train_df order).
vectorized_mcc_train = torch.cat(
    (vectorized_mcc_train0, vectorized_mcc_train1, vectorized_mcc_train2), dim=0
)
vectorized_mcc_train_attnmask = torch.cat(
    (vectorized_mcc_train0_attnmask, vectorized_mcc_train1_attnmask, vectorized_mcc_train2_attnmask), dim=0
)

# ids and attention mask must line up element for element
assert vectorized_mcc_train_attnmask.shape == vectorized_mcc_train.shape

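Построим encoder-модель (BERT-like) для задачи MLM. Поскольку нас больше интересуют токены ближе к концу последовательности (последние покупки), учтём это при маскировании: помимо обычного случайного маскирования, последний непаддинговый токен каждой последовательности дополнительно маскируется с вероятностью `last_tok_proba_repl` (по умолчанию 0.2):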

In [5]:
PADTOKEN = 259  # id of the "[PAD]" special token added to the tokenizer


def get_positional_encoding(seq_len, model_dim):
    """Sinusoidal positional encodings (Vaswani et al., 2017).

    Returns a float tensor of shape (seq_len, model_dim): even columns hold
    sin(pos * w_k), odd columns cos(pos * w_k), with w_k = 10000 ** (-2k / model_dim).
    Odd ``model_dim`` is supported (the last column is then a sine).
    """
    positions = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, model_dim, 2) * -(np.log(10000.0) / model_dim))
    positional_encodings = torch.zeros((seq_len, model_dim))
    positional_encodings[:, 0::2] = torch.sin(positions * div_term)
    # For odd model_dim there is one fewer cosine column than sine columns.
    positional_encodings[:, 1::2] = torch.cos(positions * div_term[: model_dim // 2])
    return positional_encodings


class TransformerEncoder(nn.Module):
    """One post-LayerNorm transformer encoder layer (self-attention + GELU feed-forward).

    Works sequence-first: ``x`` is (seq, batch, model_dim), as expected by
    ``nn.MultiheadAttention`` with its default ``batch_first=False``.
    """

    def __init__(self, model_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order: it determines parameter init and state_dict keys.
        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads, dropout=dropout)
        self.layer_norm1 = nn.LayerNorm(model_dim)
        expand, activation, project = nn.Linear(model_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, model_dim)
        self.ffn = nn.Sequential(expand, activation, project)
        self.layer_norm2 = nn.LayerNorm(model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask):
        """``attn_mask`` is used as ``key_padding_mask`` (batch, seq); True marks keys to ignore."""
        attended, _ = self.attention(x, x, x, key_padding_mask=attn_mask)
        hidden = self.layer_norm1(x + self.dropout(attended))
        return self.layer_norm2(hidden + self.dropout(self.ffn(hidden)))

def mask_input(data, mask_token=256, pad_token=PADTOKEN, mask_prob=0.2, 
               orig_mask_proba_repl=0.1, random_mask_proba_repl=0.1,
               last_tok_proba_repl=0.2):
    """Build a BERT-style MLM corruption of a 2-D batch of token ids.

    Every non-pad position is selected with probability ``mask_prob`` and replaced by
    ``mask_token``. Of the selected positions, a fraction ``orig_mask_proba_repl`` keeps
    its original token, and of the remaining ones a fraction ``random_mask_proba_repl``
    gets a random token drawn from the non-pad tokens of ``data`` (with defaults:
    ~81% [MASK], 10% original, ~9% random). Independently, with probability
    ``last_tok_proba_repl`` each row's last non-pad token is forced to ``mask_token``,
    putting extra weight on the most recent purchase.

    Args:
        data: 2-D tensor of token ids, one sequence per row. Not modified.
        mask_token: id written into masked positions (assumed to be the [MASK] id).
        pad_token: id of padding; padded positions are never selected.
    Returns:
        ``(masked_data, mask)``: int32 corrupted ids and a bool tensor marking every
        selected position, both shaped like ``data``.
    """
    device = data.device
    non_pad = data != pad_token
    mask = (torch.rand_like(data, dtype=torch.float) < mask_prob) & non_pad
    masked_data = data.clone()
    masked_data[mask] = mask_token
    rows, cols = torch.nonzero(mask, as_tuple=True)
    n_selected = rows.size(0)

    keep_orig = torch.rand(n_selected, device=device) < orig_mask_proba_repl
    use_random = ~keep_orig & (torch.rand(n_selected, device=device) < random_mask_proba_repl)

    masked_data[rows[keep_orig], cols[keep_orig]] = data[rows[keep_orig], cols[keep_orig]]

    n_random = int(use_random.sum())
    if n_random > 0:  # also avoids torch.randint(0, ...) when data has no non-pad tokens
        pool = data[non_pad]
        picks = pool[torch.randint(len(pool), (n_random,), device=device)]
        masked_data[rows[use_random], cols[use_random]] = picks

    # Extra masking of each row's last non-pad token, vectorized (was a per-row Python loop).
    n_rows, seq_len = data.shape
    if n_rows and seq_len:
        positions = torch.arange(1, seq_len + 1, device=device)
        # Largest (position + 1) among non-pad cells, minus 1 -> last non-pad index (-1 for all-pad rows).
        last_idx = (non_pad * positions).max(dim=1).values - 1
        pick = non_pad.any(dim=1) & (torch.rand(n_rows, device=device) < last_tok_proba_repl)
        pick_rows = pick.nonzero(as_tuple=True)[0]
        masked_data[pick_rows, last_idx[pick_rows]] = mask_token
        mask[pick_rows, last_idx[pick_rows]] = True

    return masked_data.int(), mask.bool()


Поступим, как при обучении RoBERTa: маскирование для MLM будем делать динамически — заново на каждой эпохе, а не один раз. Для этого на каждой эпохе пересоздаём обучающий датасет и даталоадер.

In [6]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class TrxDataset(Dataset):
    """In-memory MLM dataset; every tensor is moved to DEVICE once, at construction.

    Items are ``(masked ids, key-padding mask, MLM mask, target ids)``.
    """

    def __init__(self, 
                 vectorized_mcc_masked, vectorized_mcc_attnmask, vectorized_mcc_mlmmask, vectorized_mcc_target):

        # X / model input
        self.vectorized_mcc_masked = vectorized_mcc_masked.int().to(DEVICE)
        # The tokenizer mask is True on real tokens; nn.MultiheadAttention's key_padding_mask
        # wants True on positions to ignore ([PAD]) -> invert.
        self.vectorized_mcc_attnmask = ~vectorized_mcc_attnmask.bool().to(DEVICE)
        # MLM mask: used by the loss/metrics only; it will not be supplied at inference.
        self.vectorized_mcc_mlmmask = vectorized_mcc_mlmmask.bool().to(DEVICE)
        # Y / loss targets
        self.vectorized_mcc_target = vectorized_mcc_target.int().to(DEVICE)

    def __getitem__(self, idx):
        return (self.vectorized_mcc_masked[idx], self.vectorized_mcc_attnmask[idx], self.vectorized_mcc_mlmmask[idx], self.vectorized_mcc_target[idx])
    
    def __len__(self):
        return len(self.vectorized_mcc_target)

In [7]:
# Initial MLM corruption. The training loop re-masks the train split every epoch, while the
# test split keeps this single fixed masking so validation numbers stay comparable.
vectorized_mcc_train_masked, vectorized_mcc_train_mlmmask = mask_input(vectorized_mcc_train)
vectorized_mcc_test_masked, vectorized_mcc_test_mlmmask = mask_input(vectorized_mcc_test)

trainset = TrxDataset(
    vectorized_mcc_masked=vectorized_mcc_train_masked,
    vectorized_mcc_attnmask=vectorized_mcc_train_attnmask,
    vectorized_mcc_mlmmask=vectorized_mcc_train_mlmmask,
    vectorized_mcc_target=vectorized_mcc_train,
)
testset = TrxDataset(
    vectorized_mcc_masked=vectorized_mcc_test_masked,
    vectorized_mcc_attnmask=vectorized_mcc_test_attnmask,
    vectorized_mcc_mlmmask=vectorized_mcc_test_mlmmask,
    vectorized_mcc_target=vectorized_mcc_test,
)

len(trainset), len(testset)

(771040, 192760)

In [8]:
class MaskedLoss(nn.Module):
    """Token-level cross-entropy split into MLM-masked and unmasked parts.

    Returns ``(1 - alpha) * masked_loss + alpha * unmasked_loss``, where each part is the
    mean per-token CE over non-[PAD] targets of that group. ``alpha == 0`` is the pure
    MLM objective (loss only on masked positions).
    """

    def __init__(self, alpha=0.5, class_weights=None):
        super(MaskedLoss, self).__init__()
        self.alpha = alpha
        # NOTE(review): `class_weights` is accepted but deliberately ignored (class weighting is disabled).
        self.class_weights = None
        self.criterion = nn.CrossEntropyLoss(reduction='none', weight=self.class_weights)

    def forward(self, y_pred, y_true, mask):
        """
        Args:
            y_pred: logits, assumed (batch, seq, vocab); transposed to (batch, vocab, seq) for CE.
            y_true: int64 target ids, (batch, seq).
            mask: MLM mask shaped like ``y_true``, bool or 0/1 numeric; True/1 = masked position.
        """
        per_token = self.criterion(y_pred.transpose(1, 2), y_true)
        not_pad = (y_true != PADTOKEN).float()
        # Cast first: `1 - mask` raises on bool tensors, and TrxDataset yields bool MLM masks.
        mlm = mask.float()
        masked_w = mlm * not_pad
        masked_loss = (per_token * masked_w).sum() / (masked_w.sum() + 1e-8)
        if self.alpha == 0:
            return masked_loss  # skip the unmasked term entirely
        unmasked_w = (1.0 - mlm) * not_pad
        unmasked_loss = (per_token * unmasked_w).sum() / (unmasked_w.sum() + 1e-8)
        return (1 - self.alpha) * masked_loss + self.alpha * unmasked_loss

class MaskedAccuracy:
    """Token accuracy restricted to MLM-masked positions."""

    @staticmethod
    def compute(y_pred, y_true, mask):
        """Share of positions with ``mask`` set where ``argmax(y_pred) == y_true``.

        ``y_pred`` has the vocabulary on its last dim; ``mask`` holds 0/1 weights shaped
        like ``y_true``. Returns a 0-d tensor.
        """
        hits = y_pred.argmax(dim=-1).eq(y_true).float()
        weighted_hits = hits * mask
        return weighted_hits.sum() / (mask.sum() + 1e-8)
    
class UnmaskedAccuracy:
    """Token accuracy on positions that were NOT MLM-masked, ignoring [PAD] targets."""

    @staticmethod
    def compute(y_pred, y_true, mask):
        """``mask`` must be numeric 0/1 (``1 - mask`` is not supported on bool tensors).

        Returns a 0-d tensor: correct predictions / count of unmasked, non-pad positions.
        """
        real_tokens = (y_true != PADTOKEN).float()
        weights = (1 - mask) * real_tokens
        hits = (torch.argmax(y_pred, dim=-1) == y_true).float() * weights
        return hits.sum() / (weights.sum() + 1e-8)

class TransformerModel(nn.Module):
    """BERT-like MLM model: token embedding + sinusoidal positions -> encoder stack -> vocab logits.

    Public methods take batch-first token ids ``x`` (batch, seq) and ``attn_mask``, a
    key-padding mask handed to every encoder layer (True = ignore that position).
    Internally the encoder layers run sequence-first, (seq, batch, dim).
    """

    def __init__(self, vocab_size, model_dim, num_heads, ff_dim, dropout, max_seq_len, num_layers):
        super(TransformerModel, self).__init__()
        self.num_segments = 2  # NOTE(review): not used anywhere in this class
        self.embedding = nn.Embedding(vocab_size, model_dim, padding_idx=PADTOKEN)
        # Non-persistent buffer: moves with model.to(device) (no host->device copy on every
        # forward) but stays out of state_dict, so existing checkpoints still load strictly.
        self.register_buffer(
            "positional_encoding", get_positional_encoding(max_seq_len, model_dim), persistent=False
        )
        self.encoders = nn.ModuleList(
            [TransformerEncoder(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(model_dim, vocab_size)

    def forward(self, x, attn_mask):
        """Return vocabulary logits, (batch, seq, vocab_size)."""
        return self.fc(self.get_hidden_states(x, attn_mask))
    
    def get_mlm_preds(self, x, mask, attn_mask):
        """Return argmax token ids at positions where ``mask == 1`` (flattened over the batch)."""
        x = self.forward(x, attn_mask)
        return torch.argmax(x[(mask == 1)], dim=-1) 

    def get_hidden_states(self, x, attn_mask):
        """Return final encoder states before the LM head, (batch, seq, model_dim)."""
        embeddings = self.embedding(x)
        # .to() is a no-op once the buffer already lives on the embeddings' device.
        embeddings = embeddings + self.positional_encoding[:embeddings.size(1), :].to(embeddings.device)
        hidden = embeddings.permute(1, 0, 2)  # -> (seq, batch, dim) for the sequence-first encoders
        for enc in self.encoders:
            hidden = enc(hidden, attn_mask)
        return hidden.permute(1, 0, 2)

Выполним небольшой грид-серч (насколько позволяют время и вычислительные ресурсы); логирование ведётся в локальный сервер MLflow.

```mlflow ui --host 0.0.0.0 --port 5000```

In [None]:
from itertools import product

# Grid of (model_dim, num_heads, ff_dim, dropout, num_layers)
combs = list(product(
    [256, 512],  # 0 - model dim
    [4, 8],      # 1 - attention heads
    [256, 512],  # 2 - feed-forward dim
    [.05],       # 3 - dropout
    [2, 4],      # 4 - number of encoder layers
))

BATCH_SIZE = 16
MAX_INPUT_LEN = 256
EPOCHS = 3
LEARNING_RATE = 1e-3
ALPHA_LOSS = 0  # 0 -> loss only on MLM-masked positions (true MLM)
OTHER_NOTES = "1st train"

# Loop-invariant setup, hoisted out of the grid loop.
mlflow.set_tracking_uri("http://0.0.0.0:5000")
mlflow.set_experiment("mlm_dynamic_trx_vtb")
mlflow.end_run()  # close any run left active by an interrupted earlier execution of this cell
vocab_size = len(tokenizer.get_vocab())
# The validation split keeps one fixed masking (built in the previous cell), so one loader serves every run.
valloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
# == len(DataLoader(trainset, batch_size=BATCH_SIZE)) with drop_last=False; the train loader is
# rebuilt every epoch but its length never changes.
steps_per_epoch = (len(vectorized_mcc_train) + BATCH_SIZE - 1) // BATCH_SIZE


def _run_epoch(model, loader, criterion, f1_metric, desc, optimizer=None, scheduler=None):
    """Run one pass over `loader`: a training pass when `optimizer` is given, otherwise evaluation.

    Evaluation runs with autograd disabled (no graph is built). Metrics are means of per-batch
    values (kept that way for comparability with earlier MLflow runs), except macro-F1, which is
    computed once over all MLM-masked positions of the pass.
    Returns (loss, masked accuracy, unmasked accuracy, macro-F1).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = total_acc = total_um_acc = 0.0
    all_preds, all_targets = [], []
    with torch.set_grad_enabled(training):
        for masked_ids, pad_mask, mlm_mask, target in tqdm(loader, desc=desc):
            if training:
                optimizer.zero_grad()
            outputs = model(masked_ids, pad_mask)
            loss = criterion(outputs, target.long(), mlm_mask)
            if training:
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
            total_loss += loss.item()
            mlm_mask_int = mlm_mask.int()
            total_acc += MaskedAccuracy.compute(outputs, target, mlm_mask_int).item()
            total_um_acc += UnmaskedAccuracy.compute(outputs, target, mlm_mask_int).item()
            selected = mlm_mask == 1
            all_preds.append(torch.argmax(outputs, dim=-1)[selected])
            all_targets.append(target[selected])
    f1_metric.reset()
    f1 = f1_metric(torch.cat(all_preds).to(DEVICE), torch.cat(all_targets).to(DEVICE))
    n_batches = len(loader)
    return total_loss / n_batches, total_acc / n_batches, total_um_acc / n_batches, f1.item()


for comb in tqdm(combs):
    MODEL_DIM, NUM_HEADS, FF_DIM, DROPOUT, NUM_LAYERS = comb

    model = TransformerModel(vocab_size, MODEL_DIM, NUM_HEADS, FF_DIM, DROPOUT, MAX_INPUT_LEN, NUM_LAYERS)
    model.to(DEVICE)
    torch.cuda.empty_cache()
    criterion = MaskedLoss(alpha=ALPHA_LOSS)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LEARNING_RATE, total_steps=steps_per_epoch * EPOCHS
    )
    f1_metric = F1(num_classes=vocab_size, average='macro', task='multiclass').to(DEVICE)

    # The context manager closes the MLflow run even if training raises.
    with mlflow.start_run():
        mlflow.log_params({
            "model_dim": MODEL_DIM,
            "num_heads": NUM_HEADS,
            "ff_dim": FF_DIM,
            "dropout": DROPOUT,
            "epochs": EPOCHS,
            "learning_rate": LEARNING_RATE,
            "num_layers": NUM_LAYERS,
            "alpha_loss": ALPHA_LOSS,
            "other_notes": OTHER_NOTES,
        })

        for epoch in range(EPOCHS):
            # RoBERTa-style dynamic masking: draw a fresh MLM mask every epoch. Release the
            # previous device-resident train set first so two copies never coexist on the GPU.
            trainset = trainloader = None
            vectorized_mcc_train_masked, vectorized_mcc_train_mlmmask = mask_input(vectorized_mcc_train)
            trainset = TrxDataset(
                vectorized_mcc_train_masked, vectorized_mcc_train_attnmask, vectorized_mcc_train_mlmmask, vectorized_mcc_train
            )
            trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

            train_loss_avg, train_acc_avg, train_um_acc_avg, f1_train_avg = _run_epoch(
                model, trainloader, criterion, f1_metric, f"Training Epoch {epoch+1}",
                optimizer=optimizer, scheduler=scheduler,
            )
            print(f"[TRAIN] Epoch {epoch + 1}, Loss: {train_loss_avg:.4f}, Accuracy: {train_acc_avg:.4f}, "
                f"Unmasked accuracy: {train_um_acc_avg:.4f}, F1-Macro: {f1_train_avg:.4f}")
            mlflow.log_metrics({
                "train_loss": train_loss_avg,
                "train_accuracy": train_acc_avg,
                "train_unmasked_accuracy": train_um_acc_avg,
                "train_f1_macro": f1_train_avg
            }, step=epoch)

            val_loss_avg, val_acc_avg, val_um_acc_avg, f1_val_avg = _run_epoch(
                model, valloader, criterion, f1_metric, f"validating Epoch {epoch+1}"
            )
            print(f"[val] Epoch {epoch + 1}, Loss: {val_loss_avg:.4f}, Accuracy: {val_acc_avg:.4f}, "
                f"Unmasked accuracy: {val_um_acc_avg:.4f}, F1-Macro: {f1_val_avg:.4f}")
            mlflow.log_metrics({
                "val_loss": val_loss_avg,
                "val_accuracy": val_acc_avg,
                "val_unmasked_accuracy": val_um_acc_avg,
                "val_f1_macro": f1_val_avg
            }, step=epoch)

    torch.save(model.state_dict(), f"task3_recsys/models/mcc_encoder_L{NUM_LAYERS}_D{MODEL_DIM}_FF{FF_DIM}_H{NUM_HEADS}_MLM_0_2plus_fix.pth")
    # plus - extra MLM on the last token
    # fix  - trained after the bug fix
    # The optimizer/scheduler keep references to the parameters (and AdamW state), so they must be
    # dropped too, otherwise the GPU memory is not released before the next model is allocated.
    del model, optimizer, scheduler, criterion
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()

Лучшая модель на комбинациях выше - ```mcc_encoder_L2_D256_FF256_H4_MLM_0_2plus_fix.pth```

---

Далее - уберем "голову" и сделаем avg pooling по скрытым состояниям до этой головы. Это - эмбеддинги, которые далее мы будем использовать для предсказания категории последней транзакции (MCC). 

Перезапустим ноутбук, загрузим эту модель.

In [1]:
import pandas as pd
# from tokenizers.implementations import BertWordPieceTokenizer
from transformers import PreTrainedTokenizerFast
import torch
import torch.nn as nn
import numpy as np
import gc
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

train_df = pd.read_parquet(path="task3_recsys/df_for_transformer_TRAIN.parquet")
test_df = pd.read_parquet(path="task3_recsys/df_for_transformer_TEST.parquet")

# 256-token BPE vocabulary: gave the most reasonable token distribution compared with larger vocabularies.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="task3_recsys/models/tokenizer/vocab_256_bpe.json")
special_tokens = dict(mask_token="[MASK]", cls_token="[CLS]", sep_token="[SEP]", pad_token="[PAD]")
tokenizer.add_special_tokens(special_tokens)

def tokenize_mcc(texts, max_length=256):
    """Tokenize MCC-category strings into left-padded token ids plus an attention mask.

    Returns ``(input_ids, attention_mask)`` as ``int16`` / ``bool`` tensors of shape
    ``(len(texts), max_length)``; the attention mask is True on real (non-[PAD]) tokens.
    int16 is used to save memory (token ids here are small; [PAD] is 259).
    """
    texts = list(texts)
    if not texts:  # nothing to tokenize: return correctly shaped empty tensors
        empty_ids = torch.empty((0, max_length), dtype=torch.int16)
        return empty_ids, empty_ids.bool()
    encoded = tokenizer(texts, return_tensors="pt", max_length=max_length, padding="max_length",
                        truncation=True, padding_side="left", return_token_type_ids=False)
    return encoded.input_ids.to(torch.int16), encoded.attention_mask.to(torch.bool)


def tokenize_in_chunks(texts, chunk_size=250_000, max_length=256):
    """Tokenize ``texts`` in non-overlapping chunks (bounds peak memory) and concatenate the results."""
    texts = list(texts)
    id_chunks, mask_chunks = [], []
    # max(..., 1) guarantees at least one (possibly empty) chunk, so torch.cat never gets an empty list.
    for start in range(0, max(len(texts), 1), chunk_size):
        chunk_ids, chunk_mask = tokenize_mcc(texts[start:start + chunk_size], max_length)
        id_chunks.append(chunk_ids)
        mask_chunks.append(chunk_mask)
        gc.collect()
    return torch.cat(id_chunks, dim=0), torch.cat(mask_chunks, dim=0)


vectorized_mcc_train, vectorized_mcc_train_attnmask = tokenize_in_chunks(train_df.mcc_cat_list.values)
vectorized_mcc_test, vectorized_mcc_test_attnmask = tokenize_mcc(test_df.mcc_cat_list.values)

assert vectorized_mcc_train.shape == vectorized_mcc_train_attnmask.shape
assert len(vectorized_mcc_train) == len(train_df)

In [2]:
# PADTOKEN below is hard-coded; make sure it matches the tokenizer's [PAD] id.
assert 259 == tokenizer.get_vocab()["[PAD]"], "padding not aligned"

Повторно объявим все классы и функции, нужные для инференса:

In [3]:
PADTOKEN = 259  # id of the "[PAD]" special token added to the tokenizer


def get_positional_encoding(seq_len, model_dim):
    """Sinusoidal positional encodings (Vaswani et al., 2017).

    Returns a float tensor of shape (seq_len, model_dim): even columns hold
    sin(pos * w_k), odd columns cos(pos * w_k), with w_k = 10000 ** (-2k / model_dim).
    Odd ``model_dim`` is supported (the last column is then a sine).
    """
    positions = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, model_dim, 2) * -(np.log(10000.0) / model_dim))
    positional_encodings = torch.zeros((seq_len, model_dim))
    positional_encodings[:, 0::2] = torch.sin(positions * div_term)
    # For odd model_dim there is one fewer cosine column than sine columns.
    positional_encodings[:, 1::2] = torch.cos(positions * div_term[: model_dim // 2])
    return positional_encodings


class TransformerEncoder(nn.Module):
    """One post-LayerNorm transformer encoder layer (self-attention + GELU feed-forward).

    Works sequence-first: ``x`` is (seq, batch, model_dim), as expected by
    ``nn.MultiheadAttention`` with its default ``batch_first=False``.
    """

    def __init__(self, model_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order: it determines parameter init and state_dict keys.
        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads, dropout=dropout)
        self.layer_norm1 = nn.LayerNorm(model_dim)
        expand, activation, project = nn.Linear(model_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, model_dim)
        self.ffn = nn.Sequential(expand, activation, project)
        self.layer_norm2 = nn.LayerNorm(model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask):
        """``attn_mask`` is used as ``key_padding_mask`` (batch, seq); True marks keys to ignore."""
        attended, _ = self.attention(x, x, x, key_padding_mask=attn_mask)
        hidden = self.layer_norm1(x + self.dropout(attended))
        return self.layer_norm2(hidden + self.dropout(self.ffn(hidden)))
    
class TransformerModel(nn.Module):
    """BERT-like MLM model: token embedding + sinusoidal positions -> encoder stack -> vocab logits.

    Public methods take batch-first token ids ``x`` (batch, seq) and ``attn_mask``, a
    key-padding mask handed to every encoder layer (True = ignore that position).
    Internally the encoder layers run sequence-first, (seq, batch, dim).
    """

    def __init__(self, vocab_size, model_dim, num_heads, ff_dim, dropout, max_seq_len, num_layers):
        super(TransformerModel, self).__init__()
        self.num_segments = 2  # NOTE(review): not used anywhere in this class
        self.embedding = nn.Embedding(vocab_size, model_dim, padding_idx=PADTOKEN)
        # Non-persistent buffer: moves with model.to(device) (no host->device copy on every
        # forward) but stays out of state_dict, so existing checkpoints still load strictly.
        self.register_buffer(
            "positional_encoding", get_positional_encoding(max_seq_len, model_dim), persistent=False
        )
        self.encoders = nn.ModuleList(
            [TransformerEncoder(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(model_dim, vocab_size)

    def forward(self, x, attn_mask):
        """Return vocabulary logits, (batch, seq, vocab_size)."""
        return self.fc(self.get_hidden_states(x, attn_mask))
    
    def get_mlm_preds(self, x, mask, attn_mask):
        """Return argmax token ids at positions where ``mask == 1`` (flattened over the batch)."""
        x = self.forward(x, attn_mask)
        return torch.argmax(x[(mask == 1)], dim=-1) 

    def get_hidden_states(self, x, attn_mask):
        """Return final encoder states before the LM head, (batch, seq, model_dim)."""
        embeddings = self.embedding(x)
        # .to() is a no-op once the buffer already lives on the embeddings' device.
        embeddings = embeddings + self.positional_encoding[:embeddings.size(1), :].to(embeddings.device)
        hidden = embeddings.permute(1, 0, 2)  # -> (seq, batch, dim) for the sequence-first encoders
        for enc in self.encoders:
            hidden = enc(hidden, attn_mask)
        return hidden.permute(1, 0, 2)

class EmbedTransformerModel(nn.Module):
    """Wrap a trained TransformerModel and return one mean-pooled embedding per sequence.

    The LM head (``fc``) is bypassed: final encoder states are averaged over non-[PAD]
    positions only.
    """

    def __init__(self, transformer_model):
        super(EmbedTransformerModel, self).__init__()
        self.transformer = transformer_model

    def forward(self, x):
        """
        Args:
            x: a ``(token_ids, key_padding_mask)`` pair, e.g. a TensorDataset batch;
               the mask is True at [PAD] positions.
        Returns:
            float tensor (batch_size, model_dim) of mean-pooled hidden states.
        """
        # Run wherever the wrapped model lives instead of relying on a global DEVICE.
        param = next(self.parameters(), None)
        device = param.device if param is not None else x[0].device
        inputs = x[0].int().to(device)
        pad_mask = x[1].bool().to(device)
        hidden_states = self.transformer.get_hidden_states(inputs, pad_mask)  # (batch_size, seq_len, model_dim)
        keep = (~pad_mask).unsqueeze(-1)  # True on real tokens, (batch_size, seq_len, 1)
        summed = (hidden_states * keep).sum(dim=1)
        return summed / (keep.sum(dim=1) + 1e-8)
        
MODEL_DIM = 256
NUM_HEADS = 4
FF_DIM = 256
DROPOUT = .05
NUM_LAYERS = 2
MAX_INPUT_LEN = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Best grid-search checkpoint. Built with the same template the training loop saves with;
# the `_fix` suffix marks weights trained after the bug fix (the version named as best above).
CHECKPOINT_PATH = (
    f"task3_recsys/models/mcc_encoder_L{NUM_LAYERS}_D{MODEL_DIM}_FF{FF_DIM}_H{NUM_HEADS}_MLM_0_2plus_fix.pth"
)

model = TransformerModel(len(tokenizer.get_vocab()), MODEL_DIM, NUM_HEADS, FF_DIM, DROPOUT, MAX_INPUT_LEN, NUM_LAYERS)
model.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True, map_location=DEVICE))
model.to(DEVICE)

embed_model = EmbedTransformerModel(model)
embed_model.to(DEVICE)
embed_model.eval();

Будем смотреть результат и финальные метрики на тест-части и эмбедить ее же. Но оптимизировать саму задачу рекомендации следующего MCC нужно на трейн части, поэтому эмбедим ее тоже.

In [4]:
BATCH_SIZE = 16

# Second tensor is the key-padding mask (True at [PAD]), i.e. the inverted tokenizer attention mask.
# No shuffling: embedding rows must stay in the same order as train_df / test_df.
trainset, valset = (
    TensorDataset(vectorized_mcc_train, ~vectorized_mcc_train_attnmask),
    TensorDataset(vectorized_mcc_test, ~vectorized_mcc_test_attnmask),
)
trainloader, valloader = (DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False) for ds in (trainset, valset))

In [5]:
def embed_loader(model, loader):
    """Return the model's output for every batch of ``loader``, stacked into one numpy array in loader order."""
    model.eval()  # make sure dropout is off even if the model was switched back to train mode
    chunks = []
    with torch.no_grad():  # no autograd graph, so no .detach() needed
        for batch in tqdm(loader):
            chunks.append(model(batch).cpu().numpy())
    return np.concatenate(chunks)


# Loaders are unshuffled, so row i of each array corresponds to row i of train_df / test_df.
train_embs = embed_loader(embed_model, trainloader)
val_embs = embed_loader(embed_model, valloader)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48190/48190 [01:36<00:00, 497.43it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12048/12048 [00:23<00:00, 502.15it/s]


In [12]:
def attach_embeddings(df, embs):
    """Append embedding columns named "0", "1", ... to ``df``, aligned by row position.

    String column names: parquet stores column names as strings, and pyarrow warns that
    mixed int/str names get converted and will not round-trip, so the in-memory frame now
    matches what is written. Reusing ``df.index`` aligns by position, so a non-RangeIndex
    ``df`` cannot be silently misaligned the way an index-based ``join`` would be.
    """
    assert len(embs) == len(df), f"{len(embs)} embeddings for {len(df)} rows"
    emb_df = pd.DataFrame(embs, index=df.index, columns=[str(i) for i in range(embs.shape[1])])
    return pd.concat([df, emb_df], axis=1)


final_train_df = attach_embeddings(train_df, train_embs)
final_test_df = attach_embeddings(test_df, val_embs)

In [17]:
# Persist the frames with the appended embedding columns for the downstream next-MCC model.
final_train_df.to_parquet(path="task3_recsys/train_df_after_transformer.parquet")
final_test_df.to_parquet(path="task3_recsys/test_df_after_transformer.parquet")

  table = self.api.Table.from_pandas(df, **from_pandas_kwargs)
