In [None]:
# Run the `nvidia-smi` shell command from the notebook to show what it reports
# about the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install transformers torchtext datasets fsspec==2023.6.0

In [None]:
import torch
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
import gc

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Recursively delete the cached `imdb` dataset directory and the cached
# Hugging Face `modules` directory under ~/.cache/huggingface.
# NOTE(review): presumably done to force a clean re-download before the
# `load_dataset` call in the next cell - confirm this is still needed.
!rm -rf ~/.cache/huggingface/datasets/imdb
!rm -rf ~/.cache/huggingface/modules

In [None]:
# Load the IMDB dataset with `datasets.load_dataset`, requesting the "main"
# revision. Later cells read its 'train' and 'test' splits and the 'text' and
# 'label' fields of each example.
main_dataset = load_dataset('imdb', revision="main")

In [None]:
import torch
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
import gc
from torch.utils.data import DataLoader

class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``input_dim // num_heads``, attended per head,
    concatenated again and passed through a final linear layer.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        num_heads: Number of attention heads; must divide ``input_dim`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(self, input_dim, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        # Fail early with a clear message; otherwise the mismatch only shows up
        # later as an opaque shape error inside `view` during the forward pass.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_k = input_dim // num_heads

        self.Q = torch.nn.Linear(input_dim, input_dim)
        self.K = torch.nn.Linear(input_dim, input_dim)
        self.V = torch.nn.Linear(input_dim, input_dim)

        self.fc = torch.nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_dim).
            mask: Optional tensor of shape (batch_size, seq_length); positions
                equal to 0 are excluded as attention keys. ``None`` disables
                masking.

        Returns:
            Tensor of shape (batch_size, seq_length, input_dim).
        """
        batch_size, seq_length, _ = x.shape

        def split_heads(projected):
            # (batch, seq, input_dim) -> (batch, num_heads, seq, d_k)
            return projected.view(
                batch_size, seq_length, self.num_heads, self.d_k
            ).transpose(1, 2)

        Q = split_heads(self.Q(x))
        K = split_heads(self.K(x))
        V = split_heads(self.V(x))

        # (batch, num_heads, seq, seq), scaled by sqrt(d_k)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and queries.
            key_mask = mask.unsqueeze(1).unsqueeze(2)
            # Use the most negative finite value instead of -inf: a sequence
            # whose mask is all zeros would otherwise produce a softmax over
            # only -inf entries, i.e. NaN, which then poisons the whole batch.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(key_mask == 0, fill_value)

        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)

        # (batch, num_heads, seq, d_k)
        context = torch.matmul(attention_weights, V)

        # Concatenate the heads back: (batch, seq, input_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)

        # After concatenating the heads, a linear layer lets the per-head
        # slices of the embedding interact.
        output = self.fc(context)
        return output


class FeedForwardNetwork(torch.nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear.

    The first layer maps ``input_dim`` to ``hidden_dim`` and the second maps
    back to ``input_dim``, so the output has the same last dimension as the
    input.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = torch.nn.Linear(in_features=hidden_dim, out_features=input_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class SentimentAnalysisModel(torch.nn.Module):
    """Two-layer attention encoder with a pooled binary classification head.

    Each of the two layers is self-attention followed by a feed-forward block,
    both wrapped as ``LayerNorm(dropout(sublayer(x)) + x)``. The sequence is
    then mean-pooled over unmasked positions and passed to a small MLP that
    emits one raw logit per example (no sigmoid is applied in ``forward``).

    The model owns its Adam optimizer (``self.optimizer``), which is used by
    ``train_step``.

    Args:
        input_dim: Size of the last dimension of the input embeddings.
        hidden_dim: Hidden width of the feed-forward blocks and classifier.
        output_dim: Number of outputs of the classifier head.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=1):
        super(SentimentAnalysisModel, self).__init__()

        # Encoder layer 1
        self.multi_head_attention = MultiHeadAttention(input_dim)
        self.norm1 = torch.nn.LayerNorm(input_dim)
        self.dropout1 = torch.nn.Dropout(0.1)
        self.feed_forward = FeedForwardNetwork(input_dim, hidden_dim)
        self.norm2 = torch.nn.LayerNorm(input_dim)
        self.dropout2 = torch.nn.Dropout(0.1)

        # Encoder layer 2
        self.multi_head_attention2 = MultiHeadAttention(input_dim)
        self.norm3 = torch.nn.LayerNorm(input_dim)
        self.dropout3 = torch.nn.Dropout(0.1)
        self.feed_forward2 = FeedForwardNetwork(input_dim, hidden_dim)
        self.norm4 = torch.nn.LayerNorm(input_dim)
        self.dropout4 = torch.nn.Dropout(0.1)

        # Output layers
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(hidden_dim, output_dim)
        )

        # Must be created after every submodule above so that all of their
        # parameters are registered with the optimizer.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=2e-5, weight_decay=1e-4)

    def forward(self, x, mask=None):
        """Compute raw classification logits.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_dim).
            mask: Optional tensor of shape (batch_size, seq_length) where 0
                marks positions to ignore. ``None`` treats every position as
                valid.

        Returns:
            Logits with the trailing dimension squeezed when it has size 1,
            i.e. shape (batch_size,) for ``output_dim == 1``.
        """
        residual = x
        x = self.multi_head_attention(x, mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        residual = x
        x = self.feed_forward(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual)

        residual = x
        x = self.multi_head_attention2(x, mask)
        x = self.dropout3(x)
        x = self.norm3(x + residual)

        residual = x
        x = self.feed_forward2(x)
        x = self.dropout4(x)
        x = self.norm4(x + residual)

        if mask is None:
            # No mask: plain mean over the sequence dimension.
            pooled = x.mean(dim=1)
        else:
            # Masked mean. The weights are cast to x's dtype first so the
            # denominator is a float tensor and the clamp reliably guards
            # against division by zero for fully masked sequences.
            weights = mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        return self.classifier(pooled).squeeze(-1)

    def compute_loss(self, logits, targets):
        """Binary cross-entropy between raw ``logits`` and 0/1 ``targets``.

        Both tensors are flattened rather than squeezed: squeezing turns a
        batch of one into a 0-d tensor whose shape no longer matches the
        targets.

        Returns:
            Scalar tensor holding the mean loss over the batch.
        """
        flat_logits = logits.reshape(-1)
        flat_targets = targets.reshape(-1).to(flat_logits.dtype)
        return torch.nn.functional.binary_cross_entropy_with_logits(flat_logits, flat_targets)

    def train_step(self, inputs, mask, targets):
        """Run one optimization step on a batch and return the loss as a float."""
        self.train()
        self.optimizer.zero_grad()

        # Forward pass
        logits = self(inputs, mask)

        # Compute loss
        loss = self.compute_loss(logits, targets)

        # Backward pass and optimize, clipping the gradient norm to 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        self.optimizer.step()

        return loss.item()

    def test_model(self, inputs, mask, targets):
        """Evaluate one batch without gradients.

        Switches the module to eval mode, prints ``avg_loss accuracy`` and
        returns them.

        Returns:
            Tuple ``(avg_loss, accuracy)``: the mean BCE-with-logits loss over
            the batch and the fraction of examples whose sigmoid probability
            falls on the same side of 0.5 as the target.
        """
        self.eval()

        with torch.no_grad():
            # Remove an extra leading dimension if present
            inputs = inputs.squeeze(0) if inputs.dim() > 3 else inputs
            mask = mask.squeeze(0) if mask.dim() > 2 else mask
            targets = targets.squeeze(0) if targets.dim() > 1 else targets

            logits = self(inputs, mask).reshape(-1)
            targets = targets.reshape(-1)

            # The loss expects raw logits; feeding it probabilities would
            # apply the sigmoid twice. It is already averaged over the batch,
            # so it must not be divided by the batch size again.
            avg_loss = self.compute_loss(logits, targets).item()

            probabilities = torch.sigmoid(logits)
            predictions = (probabilities > 0.5).float()
            correct_predictions = (predictions == targets.float()).sum().item()
            total_predictions = targets.numel()

        # Guard the division so an empty batch does not raise.
        accuracy = correct_predictions / max(total_predictions, 1)

        print(avg_loss, accuracy)

        return avg_loss, accuracy

class Databuilder:
    """Turns IMDB examples into batches of frozen BERT embeddings.

    Reads the notebook-level ``main_dataset`` and ``device``. Text is tokenized
    with the ``bert-base-uncased`` tokenizer and embedded with the matching
    ``BertModel``, whose parameters are frozen.

    Args:
        batch_size: Number of examples per yielded batch; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, batch_size=32):
        # A non-positive batch size would make get_loader silently yield nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = main_dataset
        self.train_data = self.dataset['train'].shuffle(seed=42)
        self.test_data = self.dataset['test'].shuffle(seed=42)
        self.batch_size = batch_size
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.embedder = BertModel.from_pretrained('bert-base-uncased').to(device)

        # The embedder is used as a fixed feature extractor.
        for param in self.embedder.parameters():
            param.requires_grad = False

    def get_loader(self, type='train', drop_last=True):
        """Yield vectorized batches from a freshly shuffled split.

        Args:
            type: Name of the split to read from ``self.dataset``.
            drop_last: When True (the default, matching the previous
                behaviour) a trailing batch with fewer than ``batch_size``
                examples is discarded; when False it is yielded as well.

        Yields:
            ``(embeddings, attention_mask, labels)`` tuples as returned by
            ``vectorize_batch``.
        """
        batch_texts, batch_labels = [], []

        # Note: this shuffle is unseeded, so the order differs between calls.
        for example in self.dataset[type].shuffle():
            batch_texts.append(example['text'])
            batch_labels.append(int(example['label']))

            if len(batch_texts) == self.batch_size:
                yield self.vectorize_batch(batch_texts, batch_labels)

                gc.collect()

                batch_texts, batch_labels = [], []

        if batch_texts and not drop_last:
            yield self.vectorize_batch(batch_texts, batch_labels)

    def vectorize_batch(self, batch_texts, batch_labels):
        """Tokenize and embed a list of texts.

        Args:
            batch_texts: List of strings; padded to the longest item and
                truncated to at most 512 tokens.
            batch_labels: List of integer labels; may be empty for inference.

        Returns:
            Tuple ``(embeddings, attention_mask, labels)``: the embedder's
            ``last_hidden_state`` and the tokenizer's attention mask, both on
            ``device``, and the labels as an int8 tensor that is not moved to
            ``device``.
        """
        tokens = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt', max_length=512)

        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        with torch.no_grad():
            # Pass the attention mask so the embedder ignores padding tokens;
            # without it the embeddings of real tokens depend on how much
            # padding the batch happens to contain.
            embeddings = self.embedder(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state

        labels = torch.tensor(batch_labels, dtype=torch.int8)

        return (embeddings, attention_mask, labels)



In [None]:
print("Loading dataset...")
epochs = 10

databuilder = Databuilder(batch_size=128)
transformer = SentimentAnalysisModel(input_dim=768, hidden_dim=2048, output_dim=1).to(device)

# One batch drawn from the test split is kept fixed and reused as the
# monitoring batch for the whole run.
test_batch = next(databuilder.get_loader('test'))
ti, tm, tt = (tensor.to(device) for tensor in test_batch)

for epoch in range(epochs):
    print(epoch)
    for batch in databuilder.get_loader('train'):
        inputs, masks, targets = (tensor.to(device) for tensor in batch)

        loss = transformer.train_step(inputs, masks, targets)

        # Evaluated after every training step; note that this rebinds `loss`
        # to the evaluation loss returned by test_model.
        loss, acc = transformer.test_model(ti, tm, tt)

        gc.collect()

print("Training complete.")

In [None]:
# Cleanup: run Python's garbage collector first, then ask PyTorch to release
# the cached GPU memory it is no longer using.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Test: score a few hand-written reviews with the trained model.

test_batches = ['The movie was bad', "I want my money back!", "OMG I WAS SCREAMING, IT WAS THAT GOOD!", 'AMAZING MOVIE!']

# No labels are needed for inference, so an empty label list is passed.
i, m, l = databuilder.vectorize_batch(test_batches, [])

# Put the model in eval mode (dropout off) and skip gradient tracking. The
# forward pass runs once so the printed probabilities and the labels below
# always come from the same result.
transformer.eval()
with torch.no_grad():
    probabilities = torch.sigmoid(transformer(i, m))

print(probabilities)

for text, probability in zip(test_batches, probabilities):
    sentiment = "Good" if probability > 0.5 else "Bad"
    print(text, "=>", sentiment)

