In [1]:
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel


In [2]:
%pip install datasets==2.21.0

Collecting datasets==2.21.0
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting fsspec<=2024.6.1,>=2023.1.0 (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets==2.21.0)
  Downloading fsspec-2024.6.1-py3-none-any.whl.metadata (11 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading fsspec-2024.6.1-py3-none-any.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.6/177.6 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2024.9.0
    Uninstalling fsspec-2024.9.0:
      Successfully uninstalled fsspec-2024.9.0
  Attempting uninstall: datasets
    Found existing installation: datasets 3.2.0
    Uninstalling datasets-3.2.0:
      Successfully uninstalled datasets-3.2.0

In [3]:
from datasets import load_dataset

train_dataset = load_dataset("bentrevett/multi30k", split="train")
test_dataset = load_dataset("bentrevett/multi30k", split="test")

Downloading readme:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.60M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/164k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/156k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/29000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1014 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
def filter_dataset(dataset, minlen: int, maxlen: int) -> list[dict[str, str]]:
    return [
        dataset[i]["en"]
        for i in range(len(dataset))
        if len(dataset[i]["en"].split(" ")) <= maxlen and len(dataset[i]["en"].split(" "))>=minlen
    ]

In [6]:
maxlen = 30
minlen = 5
train_filtered = filter_dataset(train_dataset, minlen, maxlen)
test_filtered = filter_dataset(test_dataset, minlen, maxlen)
print(len(train_filtered), len(test_filtered))

28945 997


In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

<torch._C.Generator at 0x7afadc10a830>

In [8]:
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
bert_model = BertModel.from_pretrained("prajjwal1/bert-tiny").to(device)
embedding_matrix = bert_model.embeddings.word_embeddings.weight
embedding_matrix.shape

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

torch.Size([30522, 128])

In [10]:
def collate_fn(
    tokenizer, batch: list[tuple[str, str]]
) -> tuple[Tensor, Tensor]:
    encoded_batch = tokenizer.batch_encode_plus(
        batch, padding="longest", return_tensors="pt", return_attention_mask=False)['input_ids']
    return encoded_batch

In [11]:
train_loader = DataLoader(train_filtered, batch_size=32, shuffle=True, collate_fn=lambda batch:collate_fn(tokenizer,batch))
test_loader = DataLoader(test_filtered, batch_size=32, shuffle=True, collate_fn=lambda batch:collate_fn(tokenizer,batch))
tokens = next(iter(train_loader))
tokens.shape

torch.Size([32, 26])

In [12]:
class RNNCell(nn.Module):
    """
    (x_{t}, h_{t-1}) -> h_{t}
    """
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim+hidden_dim, hidden_dim)
        torch.nn.init.kaiming_normal_(self.linear.weight, nonlinearity='tanh')

    def forward(self, x: Tensor, h: Tensor) -> Tensor:
        # x: B x input_dim
        # h: B x hidden_dim
        h = torch.cat([x, h], dim=1)
        h = self.linear(h)
        return F.tanh(h)

In [13]:
class RNN_encoder(nn.Module):
    def __init__(self, vocab_size: int, input_dim: int = 128, hidden_dim: int = 512, cell: nn.Module = RNNCell) -> None:
        super().__init__()
        self.embed = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.init_h = nn.Parameter(data=torch.randn(1, hidden_dim))
        self.rnn = cell(input_dim, hidden_dim)


    def forward(self, x: Tensor) -> Tensor:
        # x: B x T
        # embed(x): B x T -> B x T x input_dim
        B, T = x.shape

        x = self.embed(x)  # B x T x input_dim
        h = self.init_h.expand((B, -1)) # B x hidden_dim

        for t in range(T):
            xt = x[:, t, :]
            h = self.rnn.forward(xt, h)  # B x hidden

        return h, x # B x hidden, B x T x input_dim, с x будем считать mse

In [14]:
tokens.shape

torch.Size([32, 27])

In [14]:
encoder = RNN_encoder(
    vocab_size=len(tokenizer)
)
sent_emb,_ = encoder.forward(tokens)
print(sent_emb.shape)

torch.Size([32, 512])


In [15]:
#На выходе получаем BxTxtoken_emb_dim, и будем приближать выход ко входу через mse и считать cos_sim для accuracy 
class RNN_decoder(nn.Module):
    def __init__(self, output_dim: int = 128, hidden_dim: int = 512, cell: nn.Module = RNNCell) -> None:
        super().__init__()
        self.init_h = nn.Parameter(data=torch.randn(1, hidden_dim))
        self.rnn = cell(hidden_dim, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, output_dim)


    def forward(self, x: Tensor, T: int) -> Tensor:
        # x: B x S (S - sentence_emb_dim)
        B, S = x.shape
        x = x.unsqueeze(1).expand((B, T, S))

        h = self.init_h.expand((B, -1)) # B x hidden_dim

        res = []
        for t in range(T):
            xt = x[:, t, :]
            h = self.rnn.forward(xt, h)  # B x hidden
            y = self.lm_head(h).unsqueeze(1)  # B x 1 x hidden
            res.append(y)

        return torch.cat(res, dim=1) # B x T x token_dim(word_emb_dim)

In [40]:
decoder = RNN_decoder()
decoder.forward(sent_emb, tokens.shape[1]).shape

torch.Size([32, 25, 128])

In [16]:
class SentenceAutoEncoder(nn.Module):
  def __init__(self, hidden_dim:int=512)-> None:
    super().__init__()
    self.encoder = RNN_encoder(vocab_size=len(tokenizer),hidden_dim=hidden_dim)
    self.l = nn.Linear(hidden_dim, hidden_dim)
    self.dropout = nn.Dropout(0.1)
    self.decoder = RNN_decoder(hidden_dim=hidden_dim)

  def forward(self, x: Tensor)-> Tensor:
    B, T = x.shape
    sent_emb, word_emb  = self.encoder(x)

    output_word_emb = self.decoder(self.dropout(F.tanh(self.l(sent_emb))), T)

    return word_emb, output_word_emb



In [36]:
embedding_matrix.shape

torch.Size([30522, 128])

In [40]:
autoencoder = SentenceAutoEncoder()
word_emb, output_word_emb = autoencoder.forward(tokens)
word_emb.shape, output_word_emb.shape

(torch.Size([32, 26, 128]), torch.Size([32, 26, 128]))

In [24]:
torch.matmul(output_word_emb, embedding_matrix.T).shape

torch.Size([32, 26, 30522])

In [17]:
#Для токенов, output_emb которых находится в пространстве для input_emb(ближайший emb из emb_matrix это input_emb)
#loss не считается, в противном случае loss=mse(output_emb, input_emb)
def match_loss(tokens, word_emb, output_word_emb, all_emb):
  with torch.no_grad():
    output_word_emb_norm = output_word_emb / output_word_emb.norm(dim=-1, keepdim=True)
    all_emb_norm = all_emb / all_emb.norm(dim=-1, keepdim=True) 
    #Ближайший emb из emb_matrix по cos_sim (считать ближайший по mse сильно затратно)
    best_similarity =  torch.matmul(output_word_emb_norm, all_emb_norm.T).argmax(-1)
    mask =  best_similarity != tokens
   
  loss = (mask * ((output_word_emb-word_emb)**2).sum(-1)).sum()
  return loss, (best_similarity == tokens).float().sum()

In [33]:
match_loss(tokens, word_emb, output_word_emb, embedding_matrix)

(tensor(11367.7617, grad_fn=<SumBackward0>), tensor(0.))

In [18]:
def train_epoch(dataloader: DataLoader,model: nn.Module,optimizer: torch.optim.Optimizer):
    model.train()
    loss_total = 0
    n_total = 0
    n_correct = 0

    for tokens in dataloader:
        word_emb, output_word_emb = model(tokens.to(device))

        loss, batch_n_correct = match_loss(tokens.to(device), word_emb, output_word_emb, embedding_matrix)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        n_total += tokens.size(0)*tokens.size(1)
        loss_total += loss.item()
        n_correct += batch_n_correct.item()

    loss=loss_total / n_total
    acc = n_correct / n_total

    train_loss.append(loss)
    train_acc.append(acc)

In [19]:
@torch.no_grad()
def test_epoch(dataloader: DataLoader,model: nn.Module):
    model.eval()
    loss_total = 0
    n_total = 0
    n_correct = 0

    for tokens in dataloader:
        word_emb, output_word_emb = model(tokens.to(device))

        loss, batch_n_correct = match_loss(tokens.to(device), word_emb, output_word_emb, embedding_matrix)
        n_total += tokens.size(0)*tokens.size(1)
        loss_total += loss.item()
        n_correct += batch_n_correct.item()

    loss=loss_total / n_total
    acc = n_correct / n_total

    val_loss.append(loss)
    val_acc.append(acc)

In [20]:
torch.manual_seed(42)
model = SentenceAutoEncoder(hidden_dim=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
EPOCHS = 5
train_loss = []
val_loss = []
train_acc = []
val_acc = []

for epoch in range(EPOCHS):
    print(f"EPOCH {epoch}")
    train_epoch(train_loader,model,optimizer)
    test_epoch(test_loader,model)
    print(f"Train_loss={train_loss[-1]}, Val_loss={val_loss[-1]}")
    print(f"Train_accuracy={train_acc[-1]}, Val_accuracy={val_acc[-1]}")

EPOCH 0
Train_loss=0.7488839063390808, Val_loss=0.547196175785717
Train_accuracy=0.15504117452314975, Val_accuracy=0.04486268053580783
EPOCH 1
Train_loss=0.3687221738381444, Val_loss=0.4391373728230871
Train_accuracy=0.14464250683337038, Val_accuracy=0.054809505846850244
EPOCH 2


KeyboardInterrupt: 

In [None]:
#Само расстояние между выходом и входом уменьшается,
#но ближайший из emb_matrix редко является input_emb, так как из меньшего mse,
#не следует больший cos_sim, поэтому в следующей реализации используем cross_entopy.