In [None]:
import random
import time

import torch as T
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data, datasets

In [None]:
class Embedding(nn.Module):
  """Sum of a token embedding and a learned positional embedding, then dropout.

  Args:
    vocab_size: number of rows in the token embedding table.
    max_length: number of rows in the positional table, i.e. the longest
      sequence `forward` accepts.
    embed_dim: width of both embedding tables.
    dropout: dropout probability applied to the summed embedding.
  """

  def __init__(self, vocab_size, max_length, embed_dim, dropout=0.1):
    super(Embedding, self).__init__()
    self.word_embed = nn.Embedding(vocab_size, embed_dim)
    self.pos_embed = nn.Embedding(max_length, embed_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Embed a 2-D tensor of token ids.

    Args:
      x: index tensor with two dimensions (batch, sequence).

    Returns:
      Tensor with a trailing `embed_dim` axis appended to the shape of `x`.

    Raises:
      ValueError: if the sequence is longer than the positional table.
    """
    _, seq_length = x.shape
    max_length = self.pos_embed.num_embeddings
    if seq_length > max_length:
      raise ValueError(
          f'sequence length {seq_length} exceeds max_length {max_length}')

    # Create the position indices on the same device as the input. Picking
    # "cuda if available" here breaks whenever the model and its inputs live
    # on the CPU of a CUDA-capable machine.
    positions = T.arange(seq_length, device=x.device)

    # pos_embed(positions) is (sequence, embed_dim) and broadcasts over the
    # batch axis of word_embed(x).
    embedding = self.word_embed(x) + self.pos_embed(positions)
    return self.dropout(embedding)

In [None]:
class MHSelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention (no masking).

  Args:
    embed_dim: width of the input and output features.
    num_heads: number of attention heads; must divide `embed_dim`.

  Raises:
    ValueError: if `embed_dim` is not divisible by `num_heads`.
  """

  def __init__(self, embed_dim, num_heads):
    super(MHSelfAttention, self).__init__()
    if embed_dim % num_heads != 0:
      raise ValueError(
          f'embed_dim ({embed_dim}) must be divisible by '
          f'num_heads ({num_heads})')

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    self.w_queries = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
    self.w_keys = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
    self.w_values = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    # Mixes the concatenated per-head outputs back into embed_dim features.
    self.fc_out = nn.Linear(self.head_dim*self.num_heads, self.embed_dim)

  def _split_heads(self, projected, batch_size, sentence_len):
    """Reshape (batch, len, embed_dim) into (batch, heads, len, head_dim)."""
    return projected.reshape(
        batch_size, sentence_len, self.num_heads, self.head_dim).permute(
            0, 2, 1, 3)

  def forward(self, x):
    """Attend every position of `x` to every other position.

    Args:
      x: tensor whose first two axes are (batch, sequence) and whose last
        axis has size `embed_dim`.

    Returns:
      Tensor of shape (batch, sequence, embed_dim).
    """
    batch_size = x.shape[0]
    sentence_len = x.shape[1]

    queries = self._split_heads(self.w_queries(x), batch_size, sentence_len)
    values = self._split_heads(self.w_values(x), batch_size, sentence_len)
    # Keys are transposed to (batch, heads, head_dim, len) so that
    # queries @ keys yields (batch, heads, len, len) scores.
    keys = self._split_heads(
        self.w_keys(x), batch_size, sentence_len).transpose(-1, -2)

    attention_scores = T.einsum('bijk,bikl->bijl', queries, keys)
    # Scale by sqrt of the per-head key width: each dot product sums over
    # head_dim terms, not embed_dim.
    attention_dist = T.softmax(attention_scores /
                               (self.head_dim ** (1/2)), dim=-1)
    attention_out = T.einsum('bijk,bikl->bijl', attention_dist, values)
    concatenated_out = attention_out.permute(0, 2, 1, 3).reshape(
        batch_size, sentence_len, self.embed_dim)

    # Apply the output projection (it was created but never used before).
    return self.fc_out(concatenated_out)

In [None]:
class TransformerEncoder(nn.Module):
  """One encoder layer: self-attention and a feed-forward net, each wrapped
  in dropout, a residual connection and a LayerNorm applied after the add.

  Args:
    embed_dim: feature width of the input and output.
    num_heads: number of heads passed to the attention sub-layer.
    forward_expansion: the feed-forward hidden width is
      `forward_expansion * embed_dim`.
    dropout: dropout probability applied to both sub-layer outputs.
  """

  def __init__(self, embed_dim, num_heads, forward_expansion, dropout=0.1):
    super().__init__()
    hidden_dim = forward_expansion * embed_dim

    self.attention = MHSelfAttention(embed_dim, num_heads)
    self.norm1 = nn.LayerNorm(embed_dim)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, embed_dim),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Apply the attention sub-layer, then the feed-forward sub-layer."""
    attended = self.attention(x)
    residual = self.norm1(x + self.dropout(attended))

    transformed = self.feed_forward(residual)
    return self.norm2(residual + self.dropout(transformed))

In [None]:
class Classifier(nn.Module):
  """Sequence classifier: embed tokens, encode them, max-pool over the
  sequence axis and project to a single logit.

  Args:
    vocab_size: passed to the embedding layer.
    max_length: passed to the embedding layer.
    embed_dim: feature width used by every sub-module.
    num_heads: passed to the encoder layer.
    forward_expansion: passed to the encoder layer.
  """

  def __init__(self, vocab_size, max_length, embed_dim,
               num_heads, forward_expansion):
    super().__init__()
    self.embedder = Embedding(vocab_size, max_length, embed_dim)
    self.encoder = TransformerEncoder(embed_dim, num_heads, forward_expansion)
    self.fc = nn.Linear(embed_dim, 1)

  def forward(self, x):
    """Return one raw (pre-sigmoid) score per batch element, shape (batch, 1)."""
    encoded = self.encoder(self.embedder(x))
    # max() over dim=1 returns (values, indices); only the values are pooled.
    pooled, _ = encoded.max(dim=1)
    return self.fc(pooled)

In [None]:
# Review text: tokenised with the spaCy 'en_core_web_sm' pipeline; batches are
# produced batch-first, which the model's (batch, sequence) unpacking relies on.
TEXT = data.Field(
    tokenize='spacy',
    tokenizer_language='en_core_web_sm',
    batch_first=True,
)

# Labels are stored as floats so they can be fed to BCEWithLogitsLoss.
LABEL = data.LabelField(dtype=T.float)

In [None]:
# Load the IMDB train/test splits, processing text with TEXT and labels with LABEL.
imdb_splits = datasets.IMDB.splits(TEXT, LABEL)
train_data, test_data = imdb_splits

In [None]:
# Carve a validation set out of the training data. The split is seeded so it
# is reproducible across kernel restarts (it was previously unseeded).
# NOTE: this cell rebinds `train_data`; re-running it without re-running the
# loading cell above splits the already-reduced training set again.
SEED = 1234
random.seed(SEED)
train_data, valid_data = train_data.split(random_state=random.getstate())

In [None]:
# Upper bound passed to the text vocabulary builder.
MAX_VOCAB_SIZE = 25000

# Both vocabularies are built from the training split only, so nothing from
# the validation or test data leaks into the token/label mappings.
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE)
LABEL.build_vocab(train_data)

In [None]:
# Mini-batch size shared by the train, validation and test iterators.
BATCH_SIZE = 4

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = T.device('cuda' if T.cuda.is_available() else 'cpu')

# One iterator per split, each yielding batches already placed on `device`.
dataset_splits = (train_data, valid_data, test_data)
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    dataset_splits,
    batch_size=BATCH_SIZE,
    device=device,
)

### Training

In [None]:
# Model configuration and construction. `classifier` and `MAX_LENGTH` were
# referenced throughout the notebook but never defined, so a fresh-kernel
# "Run All" failed with NameError.
# NOTE(review): the hyperparameter values below are placeholders — confirm
# them against the configuration originally used for this experiment.
MAX_LENGTH = 512        # inputs are truncated to this length in train/evaluate
EMBED_DIM = 128
NUM_HEADS = 4           # must divide EMBED_DIM
FORWARD_EXPANSION = 4   # feed-forward hidden width = FORWARD_EXPANSION * EMBED_DIM

classifier = Classifier(
    vocab_size=len(TEXT.vocab),
    max_length=MAX_LENGTH,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    forward_expansion=FORWARD_EXPANSION,
).to(device)

# Plain SGD over every model parameter; created after the model has been
# moved to `device`.
optimizer = optim.SGD(classifier.parameters(), lr=1e-3)

In [None]:
# Binary cross-entropy computed directly on raw logits, placed on the GPU
# when one is available and on the CPU otherwise.
device = T.device('cuda') if T.cuda.is_available() else T.device('cpu')
criterion = nn.BCEWithLogitsLoss().to(device)

In [None]:
def binary_accuracy(preds, y):
    """Fraction of predictions that match the targets.

    Args:
        preds: raw logits; they are passed through a sigmoid and rounded to
            obtain hard 0/1 predictions.
        y: target tensor compared element-wise against the hard predictions.

    Returns:
        A scalar tensor: the number of matches divided by the length of the
        comparison result's first axis.
    """
    hard_preds = T.sigmoid(preds).round()
    matches = hard_preds.eq(y).float()
    return matches.sum() / len(matches)

In [None]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: interval start, in seconds.
        end_time: interval end, in seconds.

    Returns:
        Tuple `(minutes, seconds)`, both truncated toward zero with `int()`.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    seconds = int(duration - 60 * minutes)
    return minutes, seconds

In [None]:
def train(model, iterator, optimizer, criterion):
    """Run one optimisation pass over `iterator`.

    Each batch's `text` is truncated to the module-level `MAX_LENGTH` along
    axis 1, scored by `model`, and used for one optimizer step.

    Args:
        model: module mapping a token batch to scores with a trailing axis of
            size 1 (squeezed before the loss).
        iterator: sized iterable of batches exposing `.text` and `.label`.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss called as `criterion(predictions, labels)`.

    Returns:
        Tuple `(mean_loss, mean_accuracy)`, each averaged over batches.
    """
    model.train()
    loss_total = 0
    acc_total = 0

    for batch in iterator:
        optimizer.zero_grad()

        # Slicing is a no-op when the sequence already fits in MAX_LENGTH.
        tokens = batch.text[:, :MAX_LENGTH]
        logits = model(tokens).squeeze(1)

        batch_loss = criterion(logits, batch.label)
        batch_acc = binary_accuracy(logits, batch.label)

        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        acc_total += batch_acc.item()

    n_batches = len(iterator)
    return loss_total / n_batches, acc_total / n_batches

In [None]:
def evaluate(model, iterator, criterion):
    """Score `model` on every batch of `iterator` without updating it.

    The model is put in eval mode and gradients are disabled. Each batch's
    `text` is truncated to the module-level `MAX_LENGTH` along axis 1.

    Args:
        model: module mapping a token batch to scores with a trailing axis of
            size 1 (squeezed before the loss).
        iterator: sized iterable of batches exposing `.text` and `.label`.
        criterion: loss called as `criterion(predictions, labels)`.

    Returns:
        Tuple `(mean_loss, mean_accuracy)`, each averaged over batches.
    """
    model.eval()
    loss_total = 0
    acc_total = 0

    with T.no_grad():
        for batch in iterator:
            # Slicing is a no-op when the sequence already fits in MAX_LENGTH.
            tokens = batch.text[:, :MAX_LENGTH]
            logits = model(tokens).squeeze(1)

            loss_total += criterion(logits, batch.label).item()
            acc_total += binary_accuracy(logits, batch.label).item()

    n_batches = len(iterator)
    return loss_total / n_batches, acc_total / n_batches

In [None]:
# Train for a fixed number of epochs, keeping the checkpoint with the lowest
# validation loss seen so far.
N_EPOCHS = 10
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    # Time the full train + validation pass for this epoch.
    start_time = time.time()
    train_loss, train_acc = train(
        classifier, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(
        classifier, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Overwrite the saved weights only when validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        T.save(classifier.state_dict(), 'sent-classifier.pt')

    print(
        f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s',
        f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%',
        f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%',
        sep='\n',
    )

In [None]:
# Restore the weights saved at the best validation loss. `map_location`
# remaps the stored tensors onto the current `device`, so a checkpoint
# written on a GPU still loads on a CPU-only machine.
classifier.load_state_dict(T.load('sent-classifier.pt', map_location=device))

In [None]:
# Final evaluation of the restored model on the held-out test iterator.
test_metrics = evaluate(classifier, test_iterator, criterion)
test_loss, test_acc = test_metrics
print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')