In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time 
from torch import utils
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB

from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import os
import sys
import logging
import logging
# Send log records to stdout so they appear in the notebook cell output.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.WARNING,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)

# Notebook-wide settings.
VOCAB_SIZE = 15000  # default number of rows in the model's embedding table
BATCH_SIZE = 64  # batch size of the training DataLoader

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.warning(f"device is {device}")

def yeild_tokens(train_data_iter, tokenizer):
    """Lazily tokenize the text of every ``(label, comment)`` sample.

    Args:
        train_data_iter: Iterable of two-element ``(label, comment)`` samples.
        tokenizer: Callable applied to each comment.

    Yields:
        The result of ``tokenizer(comment)``, one item per sample; the label
        is ignored.
    """
    for _label, comment in train_data_iter:
        yield tokenizer(comment)
        
class TextClassificationModel(nn.Module):
    """Text classifier made of an ``nn.EmbeddingBag`` followed by one linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        num_class: Number of logits produced per sample.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super().__init__()
        bag = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        head = nn.Linear(embed_dim, num_class)
        # Both sub-modules are placed on the module-level `device` at construction.
        self.embedding = bag.to(device)
        self.fc = head.to(device)

    def forward(self, token_index):
        """Return the classifier logits for ``token_index``.

        Args:
            token_index: Index tensor handed straight to the ``EmbeddingBag``.
                In this notebook it is the padded 2-D batch built by
                ``collate_fn``, so each row is pooled as one bag.
        """
        pooled = self.embedding(token_index)
        return self.fc(pooled)



In [2]:
# Build the vocabulary from the tokenized IMDB training split.
tokenizer = get_tokenizer("basic_english")
train_data_iter = IMDB(root="data", split="train")

vocab = build_vocab_from_iterator(
    yeild_tokens(train_data_iter, tokenizer),
    min_freq=20,
    specials=["<unk>"],
)
# Tokens that are not in the vocabulary are looked up as index 0.
vocab.set_default_index(0)
print(f'Vocal size: {len(vocab)}')

Vocal size: 13351


In [3]:
def collate_fn(batch):
    """Turn a list of ``(label, comment)`` samples into padded index tensors.

    Each comment is tokenized with the module-level ``tokenizer`` and mapped
    to indices with the module-level ``vocab``; shorter sequences are
    right-padded with 0 up to the longest sequence in the batch.

    Args:
        batch: Sequence of ``(label, comment)`` pairs.

    Returns:
        tuple: ``(target, token_index)``. ``target`` is an ``int64`` tensor
        with one entry per sample (0 for a positive label, 1 otherwise) and
        ``token_index`` is an ``int32`` tensor of shape
        ``(len(batch), longest_sequence)``.
    """
    # torchtext's IMDB dataset yields "pos"/"neg" strings in some releases and
    # the integers 2/1 in others; accept both encodings of the positive class
    # so that integer labels are not all silently mapped to class 1.
    positive_labels = ("pos", 2)

    target = []
    token_index = []
    for label, comment in batch:
        token_index.append(vocab(tokenizer(comment)))
        target.append(0 if label in positive_labels else 1)

    max_length = max(map(len, token_index), default=0)
    # NOTE: the pad value 0 is also the index this notebook assigns to "<unk>"
    # (the only special token), so padding and unknown words are indistinguishable
    # to the model.
    padded = [index + [0] * (max_length - len(index)) for index in token_index]
    return (
        torch.tensor(target, dtype=torch.int64),
        torch.tensor(padded, dtype=torch.int32),
    )

In [4]:
def _one_hot_bce(logits, target):
    """Binary cross-entropy between ``logits`` and the one-hot encoding of ``target``."""
    one_hot = F.one_hot(target, num_classes=logits.shape[-1]).to(torch.float32)
    # Fused sigmoid + BCE: same quantity as sigmoid followed by
    # binary_cross_entropy, but numerically stable for large-magnitude logits.
    return F.binary_cross_entropy_with_logits(logits, one_hot)


def _evaluate(model, eval_data_loader):
    """Score ``model`` on every batch of ``eval_data_loader`` without tracking gradients.

    The model is switched to eval mode for the pass and back to train mode
    afterwards.

    Returns:
        tuple: ``(ema_eval_loss, accuracy)`` as Python floats; ``accuracy`` is
        ``nan`` when the loader yields no samples.
    """
    ema_eval_loss = 0.0
    total_acc_account = 0
    total_account = 0
    model.eval()
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for eval_target, eval_token_index in eval_data_loader:
            eval_target = eval_target.to(device)
            eval_token_index = eval_token_index.to(device)
            eval_logits = model(eval_token_index)
            total_account += eval_target.shape[0]
            total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()
            eval_bce_loss = _one_hot_bce(eval_logits, eval_target)
            ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss.item()
    model.train()
    accuracy = total_acc_account / total_account if total_account else float("nan")
    return ema_eval_loss, accuracy


def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval,
          save_step_interval, eval_step_interval, save_path, resume=""):
    """Train ``model`` with periodic logging, checkpointing and evaluation.

    Args:
        train_data_loader: Sized iterable of ``(target, token_index)`` batches.
        eval_data_loader: Iterable of ``(target, token_index)`` batches used for evaluation.
        model: Module mapping ``token_index`` to logits.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epoch: Epoch index at which training stops (exclusive).
        log_step_interval: Log the running loss every this many steps.
        save_step_interval: Write a checkpoint every this many steps.
        eval_step_interval: Evaluate on ``eval_data_loader`` every this many steps.
        save_path: Directory that receives ``step_<n>.pt`` checkpoints.
        resume: Path of a checkpoint to restore first; ``""`` starts from scratch.
            Resuming restarts at the beginning of the checkpointed epoch.
    """
    start_epoch = 0
    if resume != "":
        logging.warning(f"loading from {resume}")
        # NOTE: torch.load unpickles the file, so only resume from trusted checkpoints.
        # map_location lets a checkpoint written on one device be restored on another.
        checkpoint = torch.load(resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']
        logging.warning(f"resuming at epoch {start_epoch} (checkpoint was written at step {start_step})")

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0.0
        num_batches = len(train_data_loader)
        start_time = time.time()

        for batch_index, (target, token_index) in enumerate(train_data_loader):
            target = target.to(device)
            token_index = token_index.to(device)
            # Global 1-based step counter across epochs.
            step = num_batches * epoch_index + batch_index + 1

            optimizer.zero_grad()
            logits = model(token_index)
            bce_loss = _one_hot_bce(logits, target)
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()

            # .item() keeps the running average a plain float, so it neither
            # holds on to the autograd graph nor logs as "tensor(...)".
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss.item()

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Detached so the checkpoint does not carry autograd state.
                    'loss': bce_loss.detach(),
                }, save_file)
                logging.warning(f"checkpoint has been saved in {save_file}")

            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                ema_eval_loss, eval_acc = _evaluate(model, eval_data_loader)
                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {eval_acc}")
        time_period = time.time() - start_time
        logging.warning(f"time cost is: {time_period}")

In [5]:
# The embedding table is sized by VOCAB_SIZE rather than by the vocabulary that
# was actually built, so fail fast here instead of hitting an out-of-range
# index lookup in the middle of training.
assert len(vocab) <= VOCAB_SIZE, f"vocabulary has {len(vocab)} tokens but VOCAB_SIZE is {VOCAB_SIZE}"

model = TextClassificationModel().to(device)
print("Model total parameters:", sum(p.numel() for p in model.parameters()))
# Path of a checkpoint to resume from; an empty string starts training from scratch.
resume = ""

Model total parameters: 960130


In [6]:
# Training loader: map-style view of the IMDB train split, reshuffled by the DataLoader.
train_data_iter = IMDB(root="data", split="train")
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# Evaluation loader: IMDB test split in dataset order, with a small fixed batch size.
eval_data_iter = IMDB(root="data", split="test")
eval_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter),
    batch_size=8,
    collate_fn=collate_fn,
)

In [7]:
# Run 1: train the model with Adam; checkpoints go to ./embed_adm.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(
    train_data_loader,
    eval_data_loader,
    model,
    optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./embed_adm",
    resume=resume,
)

epoch:   0%|          | 0/10 [00:00<?, ?it/s]



epoch:  10%|█         | 1/10 [00:10<01:38, 10.91s/it]



epoch:  20%|██        | 2/10 [00:21<01:27, 10.98s/it]



epoch:  30%|███       | 3/10 [00:32<01:16, 10.99s/it]



epoch:  40%|████      | 4/10 [00:49<01:19, 13.21s/it]



epoch:  50%|█████     | 5/10 [01:01<01:03, 12.69s/it]



epoch:  60%|██████    | 6/10 [01:12<00:48, 12.16s/it]



epoch:  70%|███████   | 7/10 [01:29<00:41, 13.86s/it]



epoch:  80%|████████  | 8/10 [01:41<00:26, 13.05s/it]



epoch:  90%|█████████ | 9/10 [01:52<00:12, 12.44s/it]



epoch: 100%|██████████| 10/10 [02:09<00:00, 12.91s/it]


In [8]:
# Run 2: AdamW. Start from a freshly initialised model so this run measures
# AdamW on its own instead of continuing from the weights that the preceding
# optimizer run already trained.
model = TextClassificationModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train(
    train_data_loader,
    eval_data_loader,
    model,
    optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./embed_admw",
    resume=resume,
)

epoch:   0%|          | 0/10 [00:00<?, ?it/s]



epoch:  10%|█         | 1/10 [00:10<01:38, 10.92s/it]



epoch:  20%|██        | 2/10 [00:21<01:27, 10.99s/it]



epoch:  30%|███       | 3/10 [00:32<01:16, 10.99s/it]



epoch:  40%|████      | 4/10 [00:49<01:19, 13.22s/it]



epoch:  50%|█████     | 5/10 [01:00<01:02, 12.41s/it]



epoch:  60%|██████    | 6/10 [01:11<00:47, 11.91s/it]



epoch:  70%|███████   | 7/10 [01:28<00:40, 13.47s/it]



epoch:  80%|████████  | 8/10 [01:39<00:25, 12.72s/it]



epoch:  90%|█████████ | 9/10 [01:50<00:12, 12.22s/it]



epoch: 100%|██████████| 10/10 [02:07<00:00, 12.72s/it]


In [9]:
# Run 3: plain SGD. Start from a freshly initialised model so this run measures
# SGD on its own instead of continuing from the weights that the preceding
# optimizer runs already trained.
model = TextClassificationModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
train(
    train_data_loader,
    eval_data_loader,
    model,
    optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./embed_sgd",
    resume=resume,
)

epoch:   0%|          | 0/10 [00:00<?, ?it/s]



epoch:  10%|█         | 1/10 [00:10<01:37, 10.82s/it]



epoch:  20%|██        | 2/10 [00:21<01:27, 10.92s/it]



epoch:  30%|███       | 3/10 [00:32<01:16, 10.88s/it]



epoch:  40%|████      | 4/10 [00:49<01:18, 13.07s/it]



epoch:  50%|█████     | 5/10 [01:00<01:01, 12.30s/it]



epoch:  60%|██████    | 6/10 [01:10<00:47, 11.82s/it]



epoch:  70%|███████   | 7/10 [01:27<00:40, 13.48s/it]



epoch:  80%|████████  | 8/10 [01:38<00:25, 12.73s/it]



epoch:  90%|█████████ | 9/10 [01:49<00:12, 12.19s/it]



epoch: 100%|██████████| 10/10 [02:06<00:00, 12.66s/it]
