<a href="https://colab.research.google.com/github/gfx73/PML-DL/blob/main/Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install allennlp



In [2]:
!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 torchdata==0.4.1 torchtext==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu113

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113


In [3]:
import sys
CLASSIFIER_PRETRAINED = False
MAX_SEQ_LEN = 1000

# PRECOMPUTE_TOK_IDS = False
# TOK_IDS_PRECOMPUTED = True
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
  PATH_TO_SAVE_ELMO_CLASSIFIER = '/content/drive/MyDrive/PML&DL/Assignment2/elmo_classifier.pt'
else:
  PATH_TO_SAVE_ELMO_CLASSIFIER = 'elmo_classifier.pt'

In [4]:
if IN_COLAB:
  from google.colab import drive
  drive.mount('/content/drive')

In [5]:
from torchtext.datasets import IMDB

IMDB_train_iter, IMDB_test_iter = IMDB()

In [6]:
from tqdm import tqdm
from torchtext.data.utils import get_tokenizer
import gc
import random


random.seed(11)
tokenizer = get_tokenizer('basic_english')

def get_labels_and_text(datasplit):
  tokens, labels = [], []
  for label, text in tqdm(datasplit):
    tokens.append(tokenizer(text))
    labels.append(label=='pos')
  return tokens, labels

train_tokens, train_labels = get_labels_and_text(IMDB_train_iter)
test_tokens, test_labels = get_labels_and_text(IMDB_test_iter)

sample_tokens_and_labels = lambda tokens, labels: zip(*random.sample(list(zip(tokens, labels)), len(labels)))

train_tokens, train_labels = sample_tokens_and_labels(train_tokens, train_labels)
test_tokens, test_labels = sample_tokens_and_labels(test_tokens, test_labels)

# val_tokens, val_labels = test_tokens[:500], test_labels[:500]
# test_tokens, test_labels = test_tokens[500:], test_labels[:500]

del IMDB_train_iter
del IMDB_test_iter
gc.collect()

25000it [00:03, 6374.72it/s]
25000it [00:03, 6505.10it/s]


0

In [7]:
from torch.utils.data import Dataset, DataLoader, Subset
from allennlp.modules.elmo import Elmo, batch_to_ids
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

# if PRECOMPUTE_TOK_IDS:
#   def save_tok_ids(all_tokens, filename_prefix, shard_size=5000):
#     all_tok_ids = []
#     for idx, tokens in tqdm(enumerate(all_tokens), total=len(all_tokens)):
#       tok_ids = (batch_to_ids([tokens])[0])
#       all_tok_ids.append(tok_ids)
#
#       if (idx + 1) % shard_size == 0 or (idx + 1) == len(all_tokens):
#         torch.save(all_tok_ids, f"{filename_prefix}{idx // shard_size}.pt")
#         del all_tok_ids
#         gc.collect()
#         all_tok_ids = []
#     return all_tok_ids
#
#   train_filename_prefix = 'train_tok_ids'
#   test_filename_prefix = 'test_tok_ids'
#
#   if not TOK_IDS_PRECOMPUTED:
#     train_tok_ids = save_tok_ids(train_tokens, train_filename_prefix)
#     del train_tokens
#     gc.collect()
#
#     test_tok_ids = save_tok_ids(test_tokens, test_filename_prefix)
#     del test_tokens
#     gc.collect()
#
#
#   class dataset(Dataset):
#     def __init__(self, labels, filename_prefix, max_len, shard_size=5000):
#       self.labels = torch.tensor(labels, dtype=torch.float32)
#       self.length = self.labels.shape[0]
#       self.filename_prefix = filename_prefix
#       self.max_len = max_len
#       self.shard_size=shard_size
#       self.cur_shard = None
#       self.cur_shard_idx = None
#
#     def __getitem__(self, idx):
#       tok_ids = self.__get_tok_ids__(idx).to(device)
#       if tok_ids.shape[0] > self.max_len:
#         tok_ids = tok_ids[:self.max_len,:]
#       else:
#         zeros = torch.zeros((self.max_len - tok_ids.shape[0], 50), dtype=tok_ids.dtype, device=device)
#         tok_ids = torch.concat((tok_ids, zeros))
#
#       return tok_ids.to(device), self.labels[idx].to(device)
#
#     def __get_tok_ids__(self, idx):
#       self.__reload_shard__(idx)
#       return self.cur_shard[idx % self.shard_size]
#
#     def __reload_shard__(self, idx):
#       shard_idx = idx // self.shard_size
#       if self.cur_shard_idx == shard_idx:
#         return
#
#       del self.cur_shard
#       gc.collect()
#       self.cur_shard = torch.load(f"{self.filename_prefix}{shard_idx}.pt")
#       self.cur_shard_idx = shard_idx
#
#     def __len__(self):
#       return self.length
#
#
#   trainset = dataset(train_labels, train_filename_prefix, MAX_SEQ_LEN)
#   testset = dataset(test_labels, test_filename_prefix, MAX_SEQ_LEN)
# else:
class dataset(Dataset):
  def __init__(self, tokens, labels):
    self.tokens = tokens
    self.labels = torch.tensor(labels, dtype=torch.float32)
    self.length = self.labels.shape[0]

  def __getitem__(self, idx):
    return self.tokens[idx], self.labels[idx]

  def __len__(self):
    return self.length


trainset = dataset(train_tokens, train_labels)
testset = dataset(test_tokens, test_labels)

valset_size = int(len(testset) * 0.02)
testset_size = len(testset) - valset_size
valset = Subset(testset, range(valset_size))
testset = Subset(testset, range(valset_size, valset_size + testset_size))

class CollateBatch(object):
  def __init__(self, batch_to_ids):
    self.batch_to_ids = batch_to_ids

  def __call__(self, batch):
    tokens_batch, labels_batch = [tokens_and_label[0] for tokens_and_label in batch], [tokens_and_label[1] for tokens_and_label in batch]
    tok_ids = self.batch_to_ids(tokens_batch).to(device)
    labels_batch = torch.tensor(labels_batch, dtype=torch.float32, device=device)
    return tok_ids, labels_batch

collateBatch = CollateBatch(batch_to_ids)

BATCH_SIZE = 6
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collateBatch)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collateBatch)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collateBatch)

In [8]:
from torch import nn
from torch.nn import functional as F

class Classifier(nn.Module):
  def __init__(self, embed_size, elmo):
    super(Classifier, self).__init__()
    self.embed_size = embed_size
    self.elmo = elmo
    self.fc1 = nn.Linear(embed_size, 1)
    
  def forward(self, input):
    embs = self.elmo(input)['elmo_representations'][0]
    mean = embs.mean(dim=1)
    x = torch.sigmoid(self.fc1(mean))
    return x


if CLASSIFIER_PRETRAINED:
  classifier = torch.load(PATH_TO_SAVE_ELMO_CLASSIFIER)
else:
  if IN_COLAB:
    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
  else:
    options_file = "options.json"
    weight_file = "weights.hdf5"
  elmo = Elmo(options_file, weight_file, dropout=0, requires_grad=False, num_output_representations=1).to(device)
  classifier = Classifier(1024, elmo=elmo).to(device)

In [9]:
sum(p.numel() for p in classifier.elmo.parameters() if p.requires_grad)

4

In [10]:
learning_rate = 0.001
epochs = 1
# l2_penalty = 0.001
l2_penalty = 0

optimizer = torch.optim.RMSprop(classifier.parameters(), lr=learning_rate, weight_decay=l2_penalty)
loss_fn = F.binary_cross_entropy_with_logits

In [11]:
!pip install torchmetrics



In [12]:
import torchmetrics

def eval_model(model, data, loss_fn):
  acc_metric = torchmetrics.Accuracy().to(device)
  prec_metric = torchmetrics.Precision().to(device)
  rec_metric = torchmetrics.Recall().to(device)
  f1_metric = torchmetrics.F1Score().to(device)
  running_loss = 0
  for x, y in tqdm(data):
    with torch.no_grad():
      y = y.reshape(-1, 1)
      with torch.autocast(device_type=device, dtype=torch.float16):
        preds = model(x)
        loss = loss_fn(preds, y)


      running_loss += loss.item()
      
      y = y.type(torch.int8)
      acc_metric(preds.round(), y)
      prec_metric(preds.round(), y)
      rec_metric(preds.round(), y)
      f1_metric(preds.round(), y)

      # print(y)
      # print(preds.round())
      # print(acc_metric.compute())

  loss = running_loss / len(data)
  acc = acc_metric.compute().item()
  prec = prec_metric.compute().item()
  rec = rec_metric.compute().item()
  f1 = f1_metric.compute().item()
  return loss, acc, prec, rec, f1

# loss, acc, prec, rec, f1 = eval_model(classifier, valloader, loss_fn)
# print("Initial metrics\tval loss: {}\tval acc: {}\tval prec: {}\tval rec: {}\tval f1: {}".format(loss, acc, prec, rec, f1))

In [None]:
torch.cuda.empty_cache()
# torch.autograd.set_detect_anomaly(True)
train_losses = []
train_accs = []
val_metrics = []

best_val_loss = 1e+8
for epoch in range(epochs):
  running_loss, correct, total = 0, 0, 0
  for iteration, (x_train ,y_train) in tqdm(enumerate(trainloader), total=len(trainloader)):
    optimizer.zero_grad()
    y_train = y_train.reshape(-1,1)
    with torch.autocast(device_type=device, dtype=torch.float16):
      preds = classifier(x_train)
      loss = loss_fn(preds, y_train)

    running_loss += loss.item()
    total += y_train.shape[0]
    correct += preds.round().eq(y_train).sum().item()

    loss.backward()
    optimizer.step()

    if iteration % 50 == 0:
      _loss = running_loss / (iteration + 1)
      acc = correct / total
      print("epoch: {}\titeration: {}\tloss: {}\tthis iteration loss: {}\taccuracy: {}".format(epoch, iteration, _loss, loss, acc))


  loss = running_loss / len(trainloader)
  acc = correct / total
  train_losses.append(loss)
  train_accs.append(acc)
  print("epoch {}\ttrain loss : {}\ttrain accuracy : {}".format(epoch, loss, acc))

  loss, acc, prec, rec, f1 = eval_model(classifier, valloader, loss_fn)
  val_metrics.append([loss, acc, prec, rec, f1])
  print("epoch: {}\tval loss: {}\tval acc: {}\tval prec: {}\tval rec: {}\tval f1: {}".format(epoch, loss, acc, prec, rec, f1))
  if best_val_loss > loss:
    torch.save(classifier, PATH_TO_SAVE_ELMO_CLASSIFIER)
    best_val_loss = loss
  if not IN_COLAB:
    torch.save(classifier, f'classifier{epoch}.pt')

  0%|          | 1/4167 [00:12<14:03:36, 12.15s/it]

epoch: 0	iteration: 0	loss: 0.977137565612793	this iteration loss: 0.977137565612793	accuracy: 0.6666666666666666


  1%|          | 51/4167 [05:31<7:59:54,  7.00s/it] 

epoch: 0	iteration: 50	loss: 0.7061653815063775	this iteration loss: 0.7873473167419434	accuracy: 0.49019607843137253


  2%|▏         | 101/4167 [10:31<6:29:09,  5.74s/it]

epoch: 0	iteration: 100	loss: 0.6983277284272826	this iteration loss: 0.7133684158325195	accuracy: 0.504950495049505


  4%|▎         | 151/4167 [15:55<5:35:05,  5.01s/it] 

epoch: 0	iteration: 150	loss: 0.6946744899086604	this iteration loss: 0.733524739742279	accuracy: 0.5077262693156733


  5%|▍         | 201/4167 [20:50<7:45:53,  7.05s/it]

epoch: 0	iteration: 200	loss: 0.6915757709474706	this iteration loss: 0.6691860556602478	accuracy: 0.511608623548922


  6%|▌         | 251/4167 [25:47<5:54:18,  5.43s/it]

epoch: 0	iteration: 250	loss: 0.6892048735542601	this iteration loss: 0.6892566680908203	accuracy: 0.5239043824701195


  7%|▋         | 301/4167 [31:09<6:46:45,  6.31s/it]

epoch: 0	iteration: 300	loss: 0.6858527583141264	this iteration loss: 0.5863710641860962	accuracy: 0.5332225913621262


  8%|▊         | 351/4167 [36:45<6:09:02,  5.80s/it]

epoch: 0	iteration: 350	loss: 0.6805021595581304	this iteration loss: 0.6318652629852295	accuracy: 0.5489078822412156


 10%|▉         | 401/4167 [41:50<7:33:28,  7.22s/it]

epoch: 0	iteration: 400	loss: 0.6760530762392981	this iteration loss: 0.6494297981262207	accuracy: 0.5631753948462178


 11%|█         | 451/4167 [46:35<4:45:15,  4.61s/it]

epoch: 0	iteration: 450	loss: 0.6735029586667232	this iteration loss: 0.7610359191894531	accuracy: 0.574648928307465


 12%|█▏        | 501/4167 [51:27<7:12:32,  7.08s/it]

epoch: 0	iteration: 500	loss: 0.6732975340888886	this iteration loss: 0.635772705078125	accuracy: 0.5798403193612774


 13%|█▎        | 551/4167 [56:21<7:02:53,  7.02s/it]

epoch: 0	iteration: 550	loss: 0.6700989151823109	this iteration loss: 0.5313658118247986	accuracy: 0.5831820931639443


 14%|█▍        | 601/4167 [1:01:44<6:35:38,  6.66s/it]

epoch: 0	iteration: 600	loss: 0.6695641694767106	this iteration loss: 0.8004767298698425	accuracy: 0.5904048807542984


 16%|█▌        | 651/4167 [1:06:58<6:06:55,  6.26s/it]

epoch: 0	iteration: 650	loss: 0.6680788634467967	this iteration loss: 0.6173998713493347	accuracy: 0.601126472094214


 17%|█▋        | 701/4167 [1:11:32<5:24:04,  5.61s/it]

epoch: 0	iteration: 700	loss: 0.6660108682347432	this iteration loss: 0.6577993035316467	accuracy: 0.6077032810271041


 18%|█▊        | 751/4167 [1:16:30<5:24:43,  5.70s/it]

epoch: 0	iteration: 750	loss: 0.6642427365884641	this iteration loss: 0.552340030670166	accuracy: 0.6154016866400355


 19%|█▉        | 801/4167 [1:22:10<6:51:45,  7.34s/it]

epoch: 0	iteration: 800	loss: 0.6635672919610318	this iteration loss: 0.7144010066986084	accuracy: 0.6186017478152309


 20%|██        | 851/4167 [1:27:36<7:41:31,  8.35s/it]

epoch: 0	iteration: 850	loss: 0.662641859012541	this iteration loss: 0.5652734041213989	accuracy: 0.6214257735996866


 22%|██▏       | 901/4167 [1:32:48<5:56:03,  6.54s/it]

epoch: 0	iteration: 900	loss: 0.6610860550046364	this iteration loss: 0.651819109916687	accuracy: 0.6267110617832038


 23%|██▎       | 951/4167 [1:37:39<5:47:17,  6.48s/it]

epoch: 0	iteration: 950	loss: 0.6604312079812701	this iteration loss: 0.7882561087608337	accuracy: 0.632316859446197


 24%|██▍       | 1001/4167 [1:42:47<5:21:48,  6.10s/it]

epoch: 0	iteration: 1000	loss: 0.6595263288214014	this iteration loss: 0.6595432758331299	accuracy: 0.6350316350316351


 25%|██▌       | 1051/4167 [1:48:02<4:24:32,  5.09s/it]

epoch: 0	iteration: 1050	loss: 0.6583918880894567	this iteration loss: 0.6944688558578491	accuracy: 0.6374881065651761


 26%|██▋       | 1101/4167 [1:52:59<5:17:09,  6.21s/it]

epoch: 0	iteration: 1100	loss: 0.6564756291330565	this iteration loss: 0.621554970741272	accuracy: 0.641537995761429


 28%|██▊       | 1151/4167 [1:58:27<5:58:27,  7.13s/it]

epoch: 0	iteration: 1150	loss: 0.6549104529034045	this iteration loss: 0.5787568092346191	accuracy: 0.6471184477266145


 29%|██▉       | 1201/4167 [2:03:17<4:30:40,  5.48s/it]

epoch: 0	iteration: 1200	loss: 0.6547977781911178	this iteration loss: 0.8058728575706482	accuracy: 0.6520954759922287


 30%|███       | 1251/4167 [2:08:38<5:34:47,  6.89s/it]

epoch: 0	iteration: 1250	loss: 0.6536625624179458	this iteration loss: 0.6596800684928894	accuracy: 0.6562749800159872


 31%|███       | 1301/4167 [2:15:10<5:43:10,  7.18s/it] 

epoch: 0	iteration: 1300	loss: 0.6515640623976321	this iteration loss: 0.5570120215415955	accuracy: 0.6598770176787087


 32%|███▏      | 1351/4167 [2:20:30<4:05:04,  5.22s/it] 

epoch: 0	iteration: 1350	loss: 0.65057745650907	this iteration loss: 0.5762374401092529	accuracy: 0.6633358006415001


 34%|███▎      | 1401/4167 [2:25:29<4:23:24,  5.71s/it]

epoch: 0	iteration: 1400	loss: 0.6491626889938801	this iteration loss: 0.5686429738998413	accuracy: 0.6682131810611468


 35%|███▍      | 1451/4167 [2:34:50<4:24:46,  5.85s/it] 

epoch: 0	iteration: 1450	loss: 0.6486556313926807	this iteration loss: 0.5552093982696533	accuracy: 0.6714909257983


 36%|███▌      | 1501/4167 [2:39:47<4:54:45,  6.63s/it]

epoch: 0	iteration: 1500	loss: 0.6482674967321375	this iteration loss: 0.4611329734325409	accuracy: 0.6742171885409727


 37%|███▋      | 1551/4167 [2:44:52<3:44:18,  5.14s/it]

epoch: 0	iteration: 1550	loss: 0.6474197392575899	this iteration loss: 0.621996283531189	accuracy: 0.6761229314420804


 38%|███▊      | 1601/4167 [2:50:27<4:11:34,  5.88s/it]

epoch: 0	iteration: 1600	loss: 0.646567492578865	this iteration loss: 0.5352944135665894	accuracy: 0.6782219446179472


 40%|███▉      | 1651/4167 [2:56:03<5:07:46,  7.34s/it]

epoch: 0	iteration: 1650	loss: 0.6455045433783806	this iteration loss: 0.5609873533248901	accuracy: 0.6813042600444176


 41%|████      | 1701/4167 [3:00:47<5:11:41,  7.58s/it]

epoch: 0	iteration: 1700	loss: 0.6441761740406144	this iteration loss: 0.7271756529808044	accuracy: 0.6837154614932393


 42%|████▏     | 1751/4167 [3:05:30<4:18:54,  6.43s/it]

epoch: 0	iteration: 1750	loss: 0.6436686617418537	this iteration loss: 0.675286054611206	accuracy: 0.6857985912811727


 43%|████▎     | 1801/4167 [3:11:27<4:57:47,  7.55s/it]

epoch: 0	iteration: 1800	loss: 0.6428145585888826	this iteration loss: 0.7485710978507996	accuracy: 0.686933185267444


 44%|████▍     | 1851/4167 [3:16:11<3:30:33,  5.45s/it]

epoch: 0	iteration: 1850	loss: 0.6419419270151696	this iteration loss: 0.7201981544494629	accuracy: 0.6893571042679633


 46%|████▌     | 1901/4167 [3:21:07<3:23:38,  5.39s/it]

epoch: 0	iteration: 1900	loss: 0.6414879697175605	this iteration loss: 0.6919854879379272	accuracy: 0.6906891109942136


 47%|████▋     | 1951/4167 [3:26:27<3:23:47,  5.52s/it]

epoch: 0	iteration: 1950	loss: 0.6406565554504942	this iteration loss: 0.5018265843391418	accuracy: 0.692636254912011


 48%|████▊     | 2001/4167 [3:31:23<2:50:28,  4.72s/it]

epoch: 0	iteration: 2000	loss: 0.6398154963617739	this iteration loss: 0.5308161973953247	accuracy: 0.6957354656005331


 49%|████▉     | 2051/4167 [3:36:35<2:48:38,  4.78s/it]

epoch: 0	iteration: 2050	loss: 0.6396172885805267	this iteration loss: 0.6471821069717407	accuracy: 0.6973833902161547


 50%|█████     | 2101/4167 [3:41:25<3:05:11,  5.38s/it]

epoch: 0	iteration: 2100	loss: 0.6388627196538454	this iteration loss: 0.6997742652893066	accuracy: 0.6996668253212756


 52%|█████▏    | 2151/4167 [3:46:28<3:12:42,  5.74s/it]

epoch: 0	iteration: 2150	loss: 0.638046925647599	this iteration loss: 0.5832899212837219	accuracy: 0.7016116534944987


 53%|█████▎    | 2201/4167 [3:51:44<3:03:01,  5.59s/it]

epoch: 0	iteration: 2200	loss: 0.6371712979134079	this iteration loss: 0.7232033610343933	accuracy: 0.7037710131758291


 54%|█████▍    | 2251/4167 [3:56:50<3:10:26,  5.96s/it]

epoch: 0	iteration: 2250	loss: 0.6364424679608729	this iteration loss: 0.662787139415741	accuracy: 0.7056863616170591


 55%|█████▌    | 2301/4167 [4:01:27<2:43:48,  5.27s/it]

epoch: 0	iteration: 2300	loss: 0.6358905632513495	this iteration loss: 0.6066322922706604	accuracy: 0.7072287411270463


 56%|█████▋    | 2351/4167 [4:07:20<3:31:03,  6.97s/it]

epoch: 0	iteration: 2350	loss: 0.6351730493929274	this iteration loss: 0.6397579908370972	accuracy: 0.7092726499361973


 58%|█████▊    | 2401/4167 [4:13:37<3:00:21,  6.13s/it] 

epoch: 0	iteration: 2400	loss: 0.6346569330133234	this iteration loss: 0.7373355627059937	accuracy: 0.7106761071775649


 59%|█████▉    | 2451/4167 [4:18:51<2:36:35,  5.48s/it]

epoch: 0	iteration: 2450	loss: 0.6342481546793992	this iteration loss: 0.6484643816947937	accuracy: 0.7118183054535564


 60%|██████    | 2501/4167 [4:23:47<2:00:59,  4.36s/it]

epoch: 0	iteration: 2500	loss: 0.6332544677975368	this iteration loss: 0.46458810567855835	accuracy: 0.713181394109023


 61%|██████    | 2551/4167 [4:28:21<2:35:44,  5.78s/it]

epoch: 0	iteration: 2550	loss: 0.6322860704572095	this iteration loss: 0.5402002334594727	accuracy: 0.7148177185417484


 62%|██████▏   | 2601/4167 [4:33:11<2:48:53,  6.47s/it]

epoch: 0	iteration: 2600	loss: 0.6317365875339471	this iteration loss: 0.7020460367202759	accuracy: 0.7166474432910419


 64%|██████▎   | 2651/4167 [4:38:21<2:07:38,  5.05s/it]

epoch: 0	iteration: 2650	loss: 0.6309566563118273	this iteration loss: 0.4408737123012543	accuracy: 0.7183452785112536


 64%|██████▍   | 2657/4167 [4:39:04<3:05:09,  7.36s/it]

In [None]:
if IN_COLAB:
  !kill $(ps aux | awk '{print $2}')