# Use Glove in Pytorch to Finish NLP task

Author:Huihan Yang

# Preparation

In [2]:
%cd ..

d:\杨蕙菡\assignment-2-text-classification-foxintohumanbeing


In [3]:
import torchtext
import os
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from torch.nn.utils.rnn import pad_sequence
from torch import nn
import pandas as pd
import argparse
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

## Set random seed

In [4]:
seed = 114
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x16c4501f490>

# Hyperparameter

In [5]:
configs = {
    'work_dir': 'work_dir2', 
    'device': 'cuda:0',
    'batch': 32, 
    'optimizer_config': {
        'lr': 1e-4, 
    }, 
    'epoch': 100, 
    'model_name': 'bert-large-uncased',
    'enable_tb': False,
    'dropout':0.5
}

In [6]:
GLOVE_DIM = 100
GLOVE = GloVe(name='6B', dim=GLOVE_DIM)

# DataLoader

In [7]:
class TWITTERDataset(Dataset):

    def __init__(self, fname, is_train=True):
        super().__init__()
        self.tokenizer = get_tokenizer('basic_english')
        self.train = is_train
        if is_train == True:
            df = pd.read_csv(fname).iloc[:,1:]
        else:
            df = pd.read_csv(fname)
        self.lines = []
        for i in range(len(df)):
            if is_train == True:
                self.lines.append((df.iloc[i, 0], df.iloc[i, 1], df.iloc[i, 2], GLOVE.get_vecs_by_tokens(self.tokenizer(df.iloc[i, 3])), torch.tensor(df.iloc[i, 4], dtype=torch.int32) ))
            else:
                self.lines.append(( GLOVE.get_vecs_by_tokens(self.tokenizer(df.iloc[i, 3]))))
        print('Complete data preprocessing with length:', len(self.lines)) 

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index: int):
        item = self.lines[index]
        if self.train:
            return item[3], item[4]
        return item


def get_dataloader():
    def collate_fn1(batch):
        x, y = zip(*batch)
        x_pad = pad_sequence(x, batch_first=True)
        y = torch.Tensor(y)
        return x_pad, y

    def collate_fn2(batch):
        x = [item for item in batch]
        x_pad = pad_sequence(x, batch_first=True)
        return x_pad

    train_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/train_clean.csv'),
                    batch_size = configs['batch'],
                    shuffle = True,
                    collate_fn = collate_fn1)
    val_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/val_clean.csv'),
                    batch_size = configs['batch'],
                    shuffle = True,
                    collate_fn = collate_fn1)
    test_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/test.csv', False),
                    batch_size = configs['batch'],
                    shuffle = False,
                    collate_fn = collate_fn2)
    return train_dataloader,val_dataloader, test_dataloader


train_dataloader,val_dataloader, test_dataloader = get_dataloader()


Complete data preprocessing with length: 5329
Complete data preprocessing with length: 2284
Complete data preprocessing with length: 3263


# Define Model

In [8]:
device = configs['device']
class RNN(torch.nn.Module):
    def __init__(self, hidden_units=64, dropout_rate=configs['dropout']):
        super().__init__()
        self.drop = nn.Dropout(dropout_rate)
        self.rnn = nn.GRU(GLOVE_DIM, hidden_units, 1, batch_first=True)
        self.linear = nn.Linear(hidden_units, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        # x shape: [batch, max_word_length, embedding_length]
        emb = self.drop(x)
        output, _ = self.rnn(emb)
        output = output[:, -1]
        output = self.linear(output)
        output = self.sigmoid(output)
        return output

model = RNN().to(device)

# Trainning

In [12]:
import os
os.getcwd()

'd:\\杨蕙菡\\assignment-2-text-classification-foxintohumanbeing'

In [13]:
writer = SummaryWriter('1141514')
optimizer = torch.optim.Adam(model.parameters(), lr=configs['optimizer_config']['lr'])
citerion = torch.nn.BCELoss()
os.makedirs('work_dir', exist_ok=True)
best_accuracy = 0
for epoch in range(configs['epoch']):
    model.train() # Set the model to training mode
    loss_sum = 0
    dataset_len = len(train_dataloader.dataset)
    for x, y in tqdm(train_dataloader):
        batchsize = y.shape[0]
        x = x.to(device)
        y = y.to(device)
        hat_y = model(x)
        hat_y = hat_y.squeeze(-1)
        loss = citerion(hat_y, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        loss_sum += loss * batchsize
    writer.add_scalar('training loss',
                            loss_sum/dataset_len,
                            epoch)

    model.eval()  # Set the model to evaluation mode
    correct_predictions = 0
    total_predictions = 0
    test_loss = 0
    results_predict = []
    with torch.no_grad():
        for x, y in tqdm(val_dataloader):
            x = x.to(device)
            y = y.to(device)

            hat_y = model(x)
            hat_y = hat_y.squeeze(-1)

            loss = citerion(hat_y, y)
            test_loss += loss.item() * y.size(0)

            # Calculate accuracy
            predictions = (hat_y > 0.5).int()  # Convert probabilities to binary predictions
            correct_predictions += (predictions == y).sum().item()
            total_predictions += y.size(0)
            results_predict.append(predictions.cpu())

    accuracy = correct_predictions / total_predictions
    avg_test_loss = test_loss / total_predictions
    writer.add_scalar('average validation accuracy', accuracy, epoch)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        pt_path = os.path.join('baseline', 'best+original.pt')
        torch.save(model.state_dict(), pt_path)
        print('save model')
        results_predict = torch.concat(results_predict).tolist()
        id = pd.read_csv('nlp-getting-started/val_clean.csv')['id']
        prediction = pd.DataFrame()
        prediction['id'] = id.values
        prediction['target'] = results_predict
        prediction.to_csv('baseline/validation_result.csv',index=False)
        

    print(f'Epoch {epoch}. accuracy: {accuracy}')

100%|██████████| 167/167 [00:00<00:00, 353.26it/s]
100%|██████████| 72/72 [00:00<00:00, 1002.68it/s]


save model
Epoch 0. accuracy: 0.5647985989492119


100%|██████████| 167/167 [00:00<00:00, 382.24it/s]
100%|██████████| 72/72 [00:00<00:00, 759.93it/s]


Epoch 1. accuracy: 0.5647985989492119


100%|██████████| 167/167 [00:00<00:00, 370.46it/s]
100%|██████████| 72/72 [00:00<00:00, 880.40it/s]


Epoch 2. accuracy: 0.5647985989492119


100%|██████████| 167/167 [00:00<00:00, 343.13it/s]
100%|██████████| 72/72 [00:00<00:00, 999.73it/s]


save model
Epoch 3. accuracy: 0.6987740805604203


100%|██████████| 167/167 [00:00<00:00, 403.77it/s]
100%|██████████| 72/72 [00:00<00:00, 991.77it/s]


save model
Epoch 4. accuracy: 0.7342381786339754


100%|██████████| 167/167 [00:00<00:00, 392.41it/s]
100%|██████████| 72/72 [00:00<00:00, 988.95it/s]


save model
Epoch 5. accuracy: 0.7548161120840631


100%|██████████| 167/167 [00:00<00:00, 399.64it/s]
100%|██████████| 72/72 [00:00<00:00, 902.42it/s]


save model
Epoch 6. accuracy: 0.7670753064798599


100%|██████████| 167/167 [00:00<00:00, 342.73it/s]
100%|██████████| 72/72 [00:00<00:00, 913.83it/s]


save model
Epoch 7. accuracy: 0.7732049036777583


100%|██████████| 167/167 [00:00<00:00, 386.71it/s]
100%|██████████| 72/72 [00:00<00:00, 902.42it/s]


save model
Epoch 8. accuracy: 0.7753940455341506


100%|██████████| 167/167 [00:00<00:00, 399.09it/s]
100%|██████████| 72/72 [00:00<00:00, 927.31it/s]


save model
Epoch 9. accuracy: 0.7806479859894921


100%|██████████| 167/167 [00:00<00:00, 392.15it/s]
100%|██████████| 72/72 [00:00<00:00, 732.02it/s]


save model
Epoch 10. accuracy: 0.782399299474606


100%|██████████| 167/167 [00:00<00:00, 357.03it/s]
100%|██████████| 72/72 [00:00<00:00, 880.40it/s]


Epoch 11. accuracy: 0.7802101576182137


100%|██████████| 167/167 [00:00<00:00, 380.17it/s]
100%|██████████| 72/72 [00:00<00:00, 596.64it/s]


save model
Epoch 12. accuracy: 0.7832749562171629


100%|██████████| 167/167 [00:00<00:00, 379.18it/s]
100%|██████████| 72/72 [00:00<00:00, 913.84it/s]


Epoch 13. accuracy: 0.7819614711033275


100%|██████████| 167/167 [00:00<00:00, 395.36it/s]
100%|██████████| 72/72 [00:00<00:00, 880.41it/s]


Epoch 14. accuracy: 0.7806479859894921


100%|██████████| 167/167 [00:00<00:00, 349.58it/s]
100%|██████████| 72/72 [00:00<00:00, 768.01it/s]


Epoch 15. accuracy: 0.7784588441330998


100%|██████████| 167/167 [00:00<00:00, 374.25it/s]
100%|██████████| 72/72 [00:00<00:00, 784.71it/s]


save model
Epoch 16. accuracy: 0.7867775831873906


100%|██████████| 167/167 [00:00<00:00, 393.75it/s]
100%|██████████| 72/72 [00:00<00:00, 937.58it/s]


Epoch 17. accuracy: 0.7775831873905429


100%|██████████| 167/167 [00:00<00:00, 378.74it/s]
100%|██████████| 72/72 [00:00<00:00, 913.84it/s]


Epoch 18. accuracy: 0.7810858143607706


100%|██████████| 167/167 [00:00<00:00, 391.54it/s]
100%|██████████| 72/72 [00:00<00:00, 768.01it/s]


Epoch 19. accuracy: 0.7802101576182137


100%|██████████| 167/167 [00:00<00:00, 358.88it/s]
100%|██████████| 72/72 [00:00<00:00, 913.84it/s]


save model
Epoch 20. accuracy: 0.7915936952714536


100%|██████████| 167/167 [00:00<00:00, 399.62it/s]
100%|██████████| 72/72 [00:00<00:00, 877.19it/s]


Epoch 21. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 334.27it/s]
100%|██████████| 72/72 [00:00<00:00, 784.70it/s]


Epoch 22. accuracy: 0.7867775831873906


100%|██████████| 167/167 [00:00<00:00, 386.71it/s]
100%|██████████| 72/72 [00:00<00:00, 732.40it/s]


Epoch 23. accuracy: 0.7863397548161121


100%|██████████| 167/167 [00:00<00:00, 359.33it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


Epoch 24. accuracy: 0.782399299474606


100%|██████████| 167/167 [00:00<00:00, 355.53it/s]
100%|██████████| 72/72 [00:00<00:00, 783.34it/s]


Epoch 25. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 385.41it/s]
100%|██████████| 72/72 [00:00<00:00, 891.27it/s]


Epoch 26. accuracy: 0.7837127845884413


100%|██████████| 167/167 [00:00<00:00, 387.11it/s]
100%|██████████| 72/72 [00:00<00:00, 913.84it/s]


Epoch 27. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 388.57it/s]
100%|██████████| 72/72 [00:00<00:00, 913.84it/s]


Epoch 28. accuracy: 0.7863397548161121


100%|██████████| 167/167 [00:00<00:00, 375.81it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


Epoch 29. accuracy: 0.7876532399299475


100%|██████████| 167/167 [00:00<00:00, 370.46it/s]
100%|██████████| 72/72 [00:00<00:00, 880.41it/s]


Epoch 30. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 336.53it/s]
100%|██████████| 72/72 [00:00<00:00, 811.15it/s]


Epoch 31. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 394.92it/s]
100%|██████████| 72/72 [00:00<00:00, 869.79it/s]


Epoch 32. accuracy: 0.7837127845884413


100%|██████████| 167/167 [00:00<00:00, 367.55it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


Epoch 33. accuracy: 0.7880910683012259


100%|██████████| 167/167 [00:00<00:00, 369.19it/s]
100%|██████████| 72/72 [00:00<00:00, 748.11it/s]


Epoch 34. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 356.47it/s]
100%|██████████| 72/72 [00:00<00:00, 891.27it/s]


Epoch 35. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 400.59it/s]
100%|██████████| 72/72 [00:00<00:00, 896.10it/s]


Epoch 36. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 402.52it/s]
100%|██████████| 72/72 [00:00<00:00, 891.28it/s]


Epoch 37. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 393.05it/s]
100%|██████████| 72/72 [00:00<00:00, 721.92it/s]


Epoch 38. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 325.14it/s]
100%|██████████| 72/72 [00:00<00:00, 902.41it/s]


Epoch 39. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 385.33it/s]
100%|██████████| 72/72 [00:00<00:00, 859.44it/s]


Epoch 40. accuracy: 0.7863397548161121


100%|██████████| 167/167 [00:00<00:00, 361.66it/s]
100%|██████████| 72/72 [00:00<00:00, 891.27it/s]


Epoch 41. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 380.08it/s]
100%|██████████| 72/72 [00:00<00:00, 872.76it/s]


Epoch 42. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 358.90it/s]
100%|██████████| 72/72 [00:00<00:00, 707.77it/s]


Epoch 43. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 342.73it/s]
100%|██████████| 72/72 [00:00<00:00, 759.92it/s]


Epoch 44. accuracy: 0.7859019264448336


100%|██████████| 167/167 [00:00<00:00, 353.26it/s]
100%|██████████| 72/72 [00:00<00:00, 677.57it/s]


Epoch 45. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 377.98it/s]
100%|██████████| 72/72 [00:00<00:00, 880.40it/s]


Epoch 46. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 373.06it/s]
100%|██████████| 72/72 [00:00<00:00, 876.17it/s]


Epoch 47. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 374.69it/s]
100%|██████████| 72/72 [00:00<00:00, 732.84it/s]


Epoch 48. accuracy: 0.7915936952714536


100%|██████████| 167/167 [00:00<00:00, 329.06it/s]
100%|██████████| 72/72 [00:00<00:00, 823.00it/s]


Epoch 49. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 331.58it/s]
100%|██████████| 72/72 [00:00<00:00, 880.40it/s]


Epoch 50. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 368.37it/s]
100%|██████████| 72/72 [00:00<00:00, 880.41it/s]


Epoch 51. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 384.05it/s]
100%|██████████| 72/72 [00:00<00:00, 768.01it/s]


save model
Epoch 52. accuracy: 0.792907180385289


100%|██████████| 167/167 [00:00<00:00, 342.02it/s]
100%|██████████| 72/72 [00:00<00:00, 707.77it/s]


Epoch 53. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 299.39it/s]
100%|██████████| 72/72 [00:00<00:00, 733.98it/s]


Epoch 54. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 343.13it/s]
100%|██████████| 72/72 [00:00<00:00, 811.15it/s]


save model
Epoch 55. accuracy: 0.7933450087565674


100%|██████████| 167/167 [00:00<00:00, 332.15it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


Epoch 56. accuracy: 0.7894045534150613


100%|██████████| 167/167 [00:00<00:00, 317.40it/s]
100%|██████████| 72/72 [00:00<00:00, 671.38it/s]


Epoch 57. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 328.33it/s]
100%|██████████| 72/72 [00:00<00:00, 751.22it/s]


Epoch 58. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 338.76it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 59. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 370.83it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 60. accuracy: 0.7880910683012259


100%|██████████| 167/167 [00:00<00:00, 344.13it/s]
100%|██████████| 72/72 [00:00<00:00, 768.01it/s]


Epoch 61. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 346.68it/s]
100%|██████████| 72/72 [00:00<00:00, 707.77it/s]


Epoch 62. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 324.15it/s]
100%|██████████| 72/72 [00:00<00:00, 752.01it/s]


Epoch 63. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 327.40it/s]
100%|██████████| 72/72 [00:00<00:00, 784.70it/s]


save model
Epoch 64. accuracy: 0.7937828371278459


100%|██████████| 167/167 [00:00<00:00, 354.76it/s]
100%|██████████| 72/72 [00:00<00:00, 869.79it/s]


Epoch 65. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 348.45it/s]
100%|██████████| 72/72 [00:00<00:00, 793.33it/s]


Epoch 66. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 314.16it/s]
100%|██████████| 72/72 [00:00<00:00, 802.14it/s]


Epoch 67. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 349.17it/s]
100%|██████████| 72/72 [00:00<00:00, 829.80it/s]


Epoch 68. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 345.20it/s]
100%|██████████| 72/72 [00:00<00:00, 843.75it/s]


Epoch 69. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 371.28it/s]
100%|██████████| 72/72 [00:00<00:00, 744.64it/s]


Epoch 70. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 334.14it/s]
100%|██████████| 72/72 [00:00<00:00, 586.93it/s]


Epoch 71. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 322.01it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 72. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 337.20it/s]
100%|██████████| 72/72 [00:00<00:00, 700.90it/s]


Epoch 73. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 355.51it/s]
100%|██████████| 72/72 [00:00<00:00, 788.82it/s]


Epoch 74. accuracy: 0.7863397548161121


100%|██████████| 167/167 [00:00<00:00, 343.13it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


Epoch 75. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 320.46it/s]
100%|██████████| 72/72 [00:00<00:00, 829.81it/s]


Epoch 76. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 364.94it/s]
100%|██████████| 72/72 [00:00<00:00, 820.37it/s]


Epoch 77. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 367.93it/s]
100%|██████████| 72/72 [00:00<00:00, 869.79it/s]


Epoch 78. accuracy: 0.7894045534150613


100%|██████████| 167/167 [00:00<00:00, 367.21it/s]
100%|██████████| 72/72 [00:00<00:00, 859.44it/s]


Epoch 79. accuracy: 0.7915936952714536


100%|██████████| 167/167 [00:00<00:00, 320.98it/s]
100%|██████████| 72/72 [00:00<00:00, 668.45it/s]


Epoch 80. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 337.59it/s]
100%|██████████| 72/72 [00:00<00:00, 694.17it/s]


save model
Epoch 81. accuracy: 0.7942206654991243


100%|██████████| 167/167 [00:00<00:00, 320.17it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


Epoch 82. accuracy: 0.7911558669001751


100%|██████████| 167/167 [00:00<00:00, 330.27it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 83. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 360.49it/s]
100%|██████████| 72/72 [00:00<00:00, 829.80it/s]


Epoch 84. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 318.59it/s]
100%|██████████| 72/72 [00:00<00:00, 829.81it/s]


save model
Epoch 85. accuracy: 0.7959719789842382


100%|██████████| 167/167 [00:00<00:00, 344.54it/s]
100%|██████████| 72/72 [00:00<00:00, 603.85it/s]


Epoch 86. accuracy: 0.7933450087565674


100%|██████████| 167/167 [00:00<00:00, 350.31it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 87. accuracy: 0.7933450087565674


100%|██████████| 167/167 [00:00<00:00, 373.30it/s]
100%|██████████| 72/72 [00:00<00:00, 869.79it/s]


Epoch 88. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 371.02it/s]
100%|██████████| 72/72 [00:00<00:00, 829.80it/s]


Epoch 89. accuracy: 0.792907180385289


100%|██████████| 167/167 [00:00<00:00, 337.21it/s]
100%|██████████| 72/72 [00:00<00:00, 627.76it/s]


Epoch 90. accuracy: 0.7942206654991243


100%|██████████| 167/167 [00:00<00:00, 325.77it/s]
100%|██████████| 72/72 [00:00<00:00, 793.33it/s]


Epoch 91. accuracy: 0.7950963222416813


100%|██████████| 167/167 [00:00<00:00, 316.80it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


save model
Epoch 92. accuracy: 0.797723292469352


100%|██████████| 167/167 [00:00<00:00, 354.16it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 93. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 346.68it/s]
100%|██████████| 72/72 [00:00<00:00, 721.93it/s]


Epoch 94. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 316.00it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


Epoch 95. accuracy: 0.7937828371278459


100%|██████████| 167/167 [00:00<00:00, 336.24it/s]
100%|██████████| 72/72 [00:00<00:00, 694.16it/s]


Epoch 96. accuracy: 0.7955341506129597


100%|██████████| 167/167 [00:00<00:00, 343.13it/s]
100%|██████████| 72/72 [00:00<00:00, 752.01it/s]


Epoch 97. accuracy: 0.7972854640980735


100%|██████████| 167/167 [00:00<00:00, 349.96it/s]
100%|██████████| 72/72 [00:00<00:00, 772.17it/s]


Epoch 98. accuracy: 0.7937828371278459


100%|██████████| 167/167 [00:00<00:00, 366.46it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]

Epoch 99. accuracy: 0.7898423817863398





# Inference

In [14]:
results_predict = []

model.load_state_dict(torch.load('baseline/best+original.pt'))
model.eval()
with torch.no_grad():
    for x in tqdm(test_dataloader):
            x = x.to(device)
            hat_y = model(x)
            hat_y = hat_y.squeeze(-1)
            predictions = (hat_y > 0.5).int() 
            results_predict.append(predictions.cpu())

results_predict = torch.concat(results_predict).tolist()

100%|██████████| 102/102 [00:00<00:00, 498.90it/s]


# Export the results


In [15]:
id = pd.read_csv('nlp-getting-started/test.csv')['id']
prediction = pd.DataFrame()
prediction['id'] = id.values
prediction['target'] = results_predict
prediction.to_csv('prediction_result/prediction_result_baseline.csv',index=False)