<a href="https://colab.research.google.com/github/gunadhineha/molecularGNN_smiles/blob/master/Text_classification_Student_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchdata

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdata
  Downloading torchdata-0.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
Collecting portalocker>=2.0.0
  Downloading portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Installing collected packages: portalocker, torchdata
Successfully installed portalocker-2.7.0 torchdata-0.5.1


In [None]:
import torch
import torchtext
from torchtext.datasets import AG_NEWS
import torch.nn as nn
import torch.nn.functional as F
import os

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Make sure the dataset cache directory exists; exist_ok=True keeps this cell safe to re-run
# and avoids the check-then-create race of a separate isdir()/mkdir() pair.
os.makedirs('./.data', exist_ok=True)
# AG_NEWS is unpacked into its train and test splits, cached under ./.data.
train_dataset, test_dataset = AG_NEWS(root='./.data')
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Tokenizer shared by vocabulary construction and batch collation.
tokenizer = get_tokenizer('basic_english')


def yield_tokens(data_iter):
    """Lazily produce one token list per sample.

    Args:
        data_iter: iterable of two-element items; the first element is
            ignored and the second is the raw text passed to ``tokenizer``.

    Yields:
        The tokenizer output for each text, in iteration order.
    """
    yield from (tokenizer(text) for _ignored, text in data_iter)

# Build the vocabulary from the training split only. "<unk>" is registered as a special
# token and is also used as the fallback index for tokens that are not in the vocabulary.
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
unk_index = vocab["<unk>"]
vocab.set_default_index(unk_index)


def _as_text_label_pairs(samples):
    """Turn (label, text) items into a list of (text, label - 1) tuples."""
    return [(text, label - 1) for label, text in samples]


# NOTE(review): the names are rebound to materialised lists with the tuple order swapped,
# so this cell cannot be re-run on its own; re-run the data-loading cell first.
train_dataset = _as_text_label_pairs(train_dataset)
test_dataset = _as_text_label_pairs(test_dataset)

In [None]:
def generate_batch(batch):
    """Collate ``(text, label)`` samples into padded id and label tensors.

    Args:
        batch: the samples of one mini-batch, each a ``(text, label)`` pair.

    Returns:
        ``(token_ids, labels)``, both moved to ``device``. ``token_ids`` is an
        int64 tensor with one row per sample, right-padded with 0 up to the
        longest sample of the batch; ``labels`` is an int64 vector.
    """
    samples = list(batch)
    encoded = [
        torch.tensor(vocab(tokenizer(text)), dtype=torch.int64)
        for text, _ in samples
    ]
    labels = torch.tensor([int(label) for _, label in samples], dtype=torch.int64)
    # batch_first=True yields the (batch, seq) layout directly, no transpose needed.
    # NOTE(review): pad id 0 appears to coincide with "<unk>" (the only special token),
    # so padding is indistinguishable from unknown words downstream; confirm.
    token_ids = pad_sequence(encoded, batch_first=True, padding_value=0)
    return token_ids.to(device), labels.to(device)

In [None]:
# Look up a few words exactly as written (no tokenizer call here); words that are not
# in the vocabulary come back as the default "<unk>" index.
vocab.lookup_indices(['Seagate', 'ships', 'world', 'first'])

[0, 4033, 50, 47]

In [None]:
# Number of distinct labels present in the training split.
num_class = len({label for _, label in train_dataset})
# Embedding table size (one row per vocabulary entry) and embedding width.
vocab_size = len(vocab)
emsize = 64

In [None]:
# Mini-batch size shared by both loaders.
batch_size = 8
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
# Evaluation gains nothing from shuffling. generate_batch pads every batch to its longest
# sample, so a fixed order also keeps batch composition (and therefore the padded inputs)
# identical from one evaluation to the next.
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=generate_batch)

In [None]:
class SimplifiedAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate learned linear projections of
    the same input, each mapping ``D`` features to ``D`` features. No mask is
    applied, so every position attends to every other position.
    """

    def __init__(self, D):
        """Create the three ``D -> D`` projections.

        Args:
            D: size of the last (feature) dimension of the input.
        """
        super(SimplifiedAttention, self).__init__()
        self.D = D
        self.q = nn.Linear(D, D)
        self.k = nn.Linear(D, D)
        self.v = nn.Linear(D, D)

    def forward(self, X):
        """Apply self-attention along the second-to-last dimension of ``X``.

        Args:
            X: tensor of shape ``(..., seq_len, D)``. Any number of leading
                batch dimensions is accepted, including none (an unbatched
                ``(seq_len, D)`` input).

        Returns:
            Tensor with the same shape as ``X``.
        """
        Q = self.q(X)
        K = self.k(X)
        V = self.v(X)
        # Swap the last two axes instead of hard-coded axes 1 and 2, so the
        # score matrix is (..., seq_len, seq_len) whatever the batch rank is.
        S = Q @ K.transpose(-2, -1) / (self.D ** 0.5)
        # Normalise over the key axis: each query's weights sum to 1.
        A = F.softmax(S, dim=-1)
        return A @ V

In [None]:
class TransformerLayer(nn.Module):
    """Self-attention sub-block followed by a two-layer feed-forward sub-block.

    Each sub-block is wrapped in a residual connection, and layer
    normalisation is applied after the residual sum.
    """

    def __init__(self, D):
        """Build the sub-modules for feature size ``D``.

        The feed-forward part expands to ``2*D`` features and projects back
        to ``D``.
        """
        super().__init__()
        # Creation order is kept as-is: it fixes both the parameter order and
        # the order in which random initial weights are drawn.
        self.sa = SimplifiedAttention(D)
        self.ln1 = nn.LayerNorm(D)
        self.linear1 = nn.Linear(D, 2*D)
        self.linear2 = nn.Linear(2*D, D)
        self.relu = nn.ReLU()
        self.ln2 = nn.LayerNorm(D)

    def forward(self, X):
        """Return the transformed features; the output has the shape of ``X``."""
        # Attention sub-block: residual sum, then normalise.
        attended = self.ln1(X + self.sa(X))
        # Feed-forward sub-block: expand, ReLU, project back, residual sum, normalise.
        expanded = self.relu(self.linear1(attended))
        return self.ln2(attended + self.linear2(expanded))

In [None]:
class Model(nn.Module):
    """Text classifier: embedding, three transformer layers, mean pooling, linear head."""

    def __init__(self, vocab_size, embed_dim, num_class):
        """Build the network.

        Args:
            vocab_size: number of rows in the embedding table.
            embed_dim: embedding width, also the feature size of every layer.
            num_class: number of output logits.
        """
        super().__init__()
        # Creation order is kept as-is: it fixes both the parameter order and
        # the order in which random initial weights are drawn.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)
        self.trans1 = TransformerLayer(embed_dim)
        self.trans2 = TransformerLayer(embed_dim)
        self.trans3 = TransformerLayer(embed_dim)

    def forward(self, input):
        """Map token ids to class logits.

        Args:
            input: integer tensor of token ids; assumed to be laid out as
                (batch, seq), as produced by the collate function (confirm).

        Returns:
            Logits with ``num_class`` entries per sample.
        """
        hidden = self.embedding(input)
        for layer in (self.trans1, self.trans2, self.trans3):
            hidden = layer(hidden)
        # Average over dimension 1; every position contributes equally,
        # including any padded positions.
        pooled = hidden.mean(dim=1)
        return self.fc(pooled)

In [None]:
# Seed the global RNG before the first stochastic step (weight initialisation) so the
# initial weights, and later random draws taken from the same generator such as
# DataLoader shuffling, are repeatable across "Restart & Run All".
torch.manual_seed(0)
model = Model(vocab_size, emsize, num_class).to(device)
# Cross-entropy over the raw logits returned by the model.
criterion = torch.nn.CrossEntropyLoss()
# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
def train(dataloader):
    """Run one optimisation pass over ``dataloader``.

    Uses the notebook-level ``model``, ``criterion`` and ``optimizer``. Every
    500 batches (but not at batch 0) it prints the accuracy accumulated since
    the start of the pass.

    Args:
        dataloader: iterable of ``(input, target)`` batches.
    """
    model.train()
    log_interval = 500
    correct = 0
    seen = 0

    for step, (tokens, labels) in enumerate(dataloader):
        logits = model(tokens)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predictions come from the logits computed before this step's update.
        correct += (logits.argmax(1) == labels).sum().item()
        seen += labels.size(0)

        if step <= 0 or step % log_interval != 0:
            continue
        print('Train acc: {:8.3f} (batch {:5d}/{:5d})'.format(correct/seen, step, len(dataloader)))
            
def test(dataloader):
    """Evaluate the notebook-level ``model`` on ``dataloader``.

    Gradients are disabled and the model is put in eval mode. The accuracy is
    printed and also returned, so callers can record it per epoch.

    Args:
        dataloader: iterable of ``(input, target)`` batches.

    Returns:
        Fraction of correctly classified samples as a float, or ``nan`` when
        the dataloader yields no samples (instead of raising
        ``ZeroDivisionError``).
    """
    model.eval()
    correct, seen = 0, 0

    with torch.no_grad():
        for tokens, labels in dataloader:
            logits = model(tokens)
            correct += (logits.argmax(1) == labels).sum().item()
            seen += labels.size(0)

    # Guard the empty-loader case so the division cannot fail.
    accuracy = correct / seen if seen else float('nan')
    print('Test acc: {:8.3f}'.format(accuracy))
    return accuracy

In [None]:
# Alternate one training pass and one evaluation pass per epoch.
num_epochs = 5
for epoch in range(num_epochs):
    print("Epoch {}".format(epoch))
    train(trainloader)
    test(testloader)

Epoch 0
Train acc:    0.488 (batch   500/15000)
Train acc:    0.600 (batch  1000/15000)
Train acc:    0.661 (batch  1500/15000)
Train acc:    0.698 (batch  2000/15000)
Train acc:    0.725 (batch  2500/15000)
Train acc:    0.745 (batch  3000/15000)
Train acc:    0.759 (batch  3500/15000)
Train acc:    0.769 (batch  4000/15000)
Train acc:    0.778 (batch  4500/15000)
Train acc:    0.787 (batch  5000/15000)
Train acc:    0.794 (batch  5500/15000)
Train acc:    0.801 (batch  6000/15000)
Train acc:    0.807 (batch  6500/15000)
Train acc:    0.812 (batch  7000/15000)
Train acc:    0.817 (batch  7500/15000)
Train acc:    0.821 (batch  8000/15000)
Train acc:    0.824 (batch  8500/15000)
Train acc:    0.828 (batch  9000/15000)
Train acc:    0.831 (batch  9500/15000)
Train acc:    0.834 (batch 10000/15000)
Train acc:    0.836 (batch 10500/15000)
Train acc:    0.839 (batch 11000/15000)
Train acc:    0.841 (batch 11500/15000)
Train acc:    0.843 (batch 12000/15000)
Train acc:    0.845 (batch 12500

KeyboardInterrupt: ignored