# Preparation

In [1]:
# Move two directory levels up; per the recorded output this lands in the
# repository root, from which the relative data/output paths below are resolved.
# NOTE(review): not idempotent - re-running this cell climbs two more levels.
%cd ..
%cd ..

d:\杨蕙菡\assignment-2-text-classification-foxintohumanbeing\improvement
d:\杨蕙菡\assignment-2-text-classification-foxintohumanbeing


In [2]:
import torchtext
import os
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from torch.nn.utils.rnn import pad_sequence
from torch import nn
import pandas as pd
import argparse
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import re
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import nltk
# Fetch the NLTK corpora that clean_text() relies on (stop-word list and WordNet
# for the lemmatizer).
nltk.download('stopwords')
nltk.download('wordnet')
# Dimensionality of the GloVe word vectors; also used as the LSTM input size.
GLOVE_DIM = 100
# Pretrained GloVe '6B' vectors used to embed tokens in TWITTERDataset.
GLOVE = GloVe(name='6B', dim=GLOVE_DIM)

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\mayn\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\mayn\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Set random seed

In [3]:
# Seed every random number generator this notebook imports so that a
# Restart & Run All is repeatable. Python's `random` module was imported but
# previously left unseeded.
seed = 114
random.seed(seed)
np.random.seed(seed)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x2461e1fba10>

# Hyperparameters

In [4]:
# Central place for the values a reader may want to tune.
configs = {
    'work_dir': 'work_dir2',
    # Prefer the first GPU, but fall back to the CPU so the notebook still runs
    # on machines without CUDA instead of failing at `.to(device)`.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    'batch': 32,           # mini-batch size used by every DataLoader
    'optimizer_config': {
        'lr': 1e-4,        # Adam learning rate
    },
    'epoch': 100,          # number of training epochs
    'dropout': 0.5,        # dropout probability
}

In [None]:
# NOTE(review): this unexecuted cell repeats the GLOVE_DIM / GLOVE setup that the
# import cell already performs; running it only rebinds the same names.
GLOVE_DIM = 100
GLOVE = GloVe(name='6B', dim=GLOVE_DIM)

# DataLoader

In [5]:
# Lazily-filled cache: the NLTK stop-word list and the lemmatizer are built once
# and reused, instead of re-reading the stop-word corpus for every token.
_CLEAN_TEXT_RESOURCES = {}


def _get_clean_text_resources():
    """Return the cached ``(stop_words, lemmatizer)`` pair, creating it on first use."""
    if not _CLEAN_TEXT_RESOURCES:
        _CLEAN_TEXT_RESOURCES['stop_words'] = frozenset(stopwords.words('english'))
        _CLEAN_TEXT_RESOURCES['lemmatizer'] = WordNetLemmatizer()
    return _CLEAN_TEXT_RESOURCES['stop_words'], _CLEAN_TEXT_RESOURCES['lemmatizer']


def clean_text(text):
    """Normalise a raw text string for embedding lookup.

    Steps, in order: strip URLs, replace runs of non-word characters with a
    single space, lowercase, drop English stop words, lemmatize each remaining
    token with WordNet.

    Args:
        text: The raw text to clean (a ``str``).

    Returns:
        The cleaned tokens joined by single spaces; an empty string when no
        token survives.
    """
    stop_words, lemmatizer = _get_clean_text_resources()

    # Remove URLs
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)

    # Replace punctuation and other non-word characters with a space.
    # Digits and underscores count as word characters, so they are kept.
    text = re.sub(r'\W+', ' ', text)

    # Convert to lowercase
    text = text.lower()

    # Tokenize on whitespace, drop stop words (set lookup), then lemmatize
    tokens = [token for token in text.split() if token not in stop_words]
    tokens = [lemmatizer.lemmatize(token) for token in tokens]

    return ' '.join(tokens)

In [6]:
class TWITTERDataset(Dataset):
    """In-memory dataset of texts embedded as GloVe vector sequences.

    Every row's text (column index 3) is cleaned with ``clean_text``, tokenized
    with the ``basic_english`` tokenizer and embedded with ``GLOVE``. Training
    and test rows go through the same preprocessing; previously the test branch
    embedded the raw, uncleaned text and skipped the empty-text fallback.

    Args:
        fname: Path of the CSV file to read.
        is_train: When true, the first CSV column is dropped and column index 4
            is read as the label; ``__getitem__`` then yields
            ``(vectors, label)``. Otherwise only the vectors are stored.
    """

    def __init__(self, fname, is_train=True):
        super().__init__()
        self.tokenizer = get_tokenizer('basic_english')
        self.train = is_train
        df = pd.read_csv(fname)
        if is_train:
            # presumably the first column is a saved index - verify against the CSV
            df = df.iloc[:, 1:]
        self.lines = []
        for i in range(len(df)):
            cleaned_text = clean_text(df.iloc[i, 3])
            tokenized_text = self.tokenizer(cleaned_text)
            if not tokenized_text:
                # Keep at least one token so the sequence is never empty.
                tokenized_text = ['<UNK>']
            vectors = GLOVE.get_vecs_by_tokens(tokenized_text)
            if is_train:
                label = torch.tensor(df.iloc[i, 4], dtype=torch.int32)
                self.lines.append((df.iloc[i, 0], df.iloc[i, 1], df.iloc[i, 2], vectors, label))
            else:
                self.lines.append(vectors)
        print('Complete data preprocessing with length:', len(self.lines))

    def __len__(self):
        """Return the number of preprocessed rows."""
        return len(self.lines)

    def __getitem__(self, index: int):
        """Return ``(vectors, label)`` for training data, else just ``vectors``."""
        item = self.lines[index]
        if self.train:
            return item[3], item[4]
        return item
def get_dataloader():
    """Build the train, validation and test DataLoaders.

    Returns:
        A ``(train_dataloader, val_dataloader, test_dataloader)`` tuple. Only
        the training loader shuffles. The validation loader keeps file order
        because the training cell pairs its predictions with the ``id`` column
        read from ``val_clean.csv`` in file order.
    """
    def collate_fn1(batch):
        # Labelled batch: zero-pad the sequences to a common length and collect
        # the labels into one float tensor (as BCELoss expects).
        x, y = zip(*batch)
        x_pad = pad_sequence(x, batch_first=True)
        y = torch.stack(y).float()
        return x_pad, y

    def collate_fn2(batch):
        # Unlabelled batch: only zero-pad the sequences.
        x_pad = pad_sequence(list(batch), batch_first=True)
        return x_pad

    train_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/train_clean.csv'),
                    batch_size = configs['batch'],
                    shuffle = True,
                    collate_fn = collate_fn1)
    val_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/val_clean.csv'),
                    batch_size = configs['batch'],
                    shuffle = False,
                    collate_fn = collate_fn1)
    test_dataloader = DataLoader(TWITTERDataset('nlp-getting-started/test.csv', False),
                    batch_size = configs['batch'],
                    shuffle = False,
                    collate_fn = collate_fn2)
    return train_dataloader,val_dataloader, test_dataloader

# Build all three loaders once; each dataset preprocesses its whole CSV up front
# and prints its length.
train_dataloader,val_dataloader, test_dataloader = get_dataloader()


Complete data preprocessing with length: 5329
Complete data preprocessing with length: 2284
Complete data preprocessing with length: 3263


# Define Model

In [7]:
# Device on which the model and every batch are placed (taken from the config cell).
device = configs['device']

class Attention(nn.Module):
    """Pool a sequence of hidden states into one vector with learned weights.

    A single linear layer scores every time step; the scores pass through
    ``tanh`` and a softmax over the time axis (dim 1), and the hidden states
    are summed with those weights.
    """

    def __init__(self, hidden_units):
        super().__init__()
        self.hidden_units = hidden_units
        self.attn = nn.Linear(hidden_units, 1)

    def forward(self, outputs):
        """Return ``(context, attn_weights)`` for ``outputs``.

        ``attn_weights`` has a trailing dimension of size 1 and sums to 1 along
        dim 1; ``context`` is ``outputs`` reduced over dim 1.
        """
        scores = self.attn(outputs).tanh()
        attn_weights = scores.softmax(dim=1)
        context = (outputs * attn_weights).sum(dim=1)
        return context, attn_weights

class LSTMWithAttention(torch.nn.Module):
    """Binary classifier: dropout -> single-layer LSTM -> attention -> linear -> sigmoid.

    Args:
        hidden_units: Size of the LSTM hidden state, and of the attention and
            linear layer inputs.
        dropout_rate: Dropout probability applied to the input embeddings.
    """

    def __init__(self, hidden_units=64, dropout_rate=0.5):
        super().__init__()
        self.drop = nn.Dropout(p=dropout_rate)
        self.lstm = nn.LSTM(
            input_size=GLOVE_DIM,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.attention = Attention(hidden_units)
        self.linear = nn.Linear(hidden_units, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        """Map a batch-first sequence of embeddings to one probability per sample.

        The attention weights are computed but not returned.
        """
        dropped = self.drop(x)
        hidden_states, _ = self.lstm(dropped)
        context, _ = self.attention(hidden_states)
        logits = self.linear(context)
        return self.sigmoid(logits)

# Training

In [10]:
# TensorBoard run directory for this experiment.
writer = SummaryWriter('114154')
# Pass the configured dropout so that changing configs['dropout'] takes effect.
model = LSTMWithAttention(dropout_rate=configs['dropout']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=configs['optimizer_config']['lr'])
citerion = torch.nn.BCELoss()
best_accuracy = 0
# Checkpoint and validation predictions are written here; create it up front so
# the first save cannot fail on a missing directory.
output_dir = 'improvement/improvement1'
os.makedirs(output_dir, exist_ok=True)
for epoch in range(configs['epoch']):
    # ---- training pass ----
    model.train()
    loss_sum = 0.0
    dataset_len = len(train_dataloader.dataset)
    for x, y in tqdm(train_dataloader):
        batchsize = y.shape[0]
        x = x.to(device)
        y = y.to(device)
        hat_y = model(x)
        hat_y = hat_y.squeeze(-1)
        loss = citerion(hat_y, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # .item() detaches the value; accumulating the tensor itself would keep
        # every batch's autograd history alive for the whole epoch.
        loss_sum += loss.item() * batchsize
    writer.add_scalar('training loss',
                            loss_sum/dataset_len,
                            epoch)

    # ---- validation pass ----
    model.eval()  # Set the model to evaluation mode
    correct_predictions = 0
    total_predictions = 0
    test_loss = 0
    results_predict = []
    with torch.no_grad():
        for x, y in tqdm(val_dataloader):
            x = x.to(device)
            y = y.to(device)

            hat_y = model(x)
            hat_y = hat_y.squeeze(-1)

            loss = citerion(hat_y, y)
            test_loss += loss.item() * y.size(0)

            # Calculate accuracy
            predictions = (hat_y > 0.5).int()  # Convert probabilities to binary predictions
            correct_predictions += (predictions == y).sum().item()
            total_predictions += y.size(0)
            results_predict.append(predictions.cpu())

    accuracy = correct_predictions / total_predictions
    avg_test_loss = test_loss / total_predictions
    writer.add_scalar('average validation accuracy', accuracy, epoch)
    writer.add_scalar('average validation loss', avg_test_loss, epoch)
    if accuracy > best_accuracy:
        # New best epoch: keep its weights and its validation predictions.
        best_accuracy = accuracy
        pt_path = os.path.join(output_dir, 'best+LSTM.pt')
        torch.save(model.state_dict(), pt_path)
        print('save model')
        results_predict = torch.concat(results_predict).tolist()
        # NOTE: ids are read in file order, so this pairing is only correct when
        # val_dataloader iterates without shuffling.
        val_ids = pd.read_csv('nlp-getting-started/val_clean.csv')['id']
        prediction = pd.DataFrame({'id': val_ids.values, 'target': results_predict})
        prediction.to_csv(os.path.join(output_dir, 'validation_result.csv'), index=False)

    print(f'Epoch {epoch}. accuracy: {accuracy}')

# Make sure every logged scalar reaches disk.
writer.flush()

100%|██████████| 167/167 [00:00<00:00, 226.26it/s]
100%|██████████| 72/72 [00:00<00:00, 821.37it/s]


save model
Epoch 0. accuracy: 0.6160245183887916


100%|██████████| 167/167 [00:00<00:00, 316.77it/s]
100%|██████████| 72/72 [00:00<00:00, 925.55it/s]


save model
Epoch 1. accuracy: 0.7526269702276708


100%|██████████| 167/167 [00:00<00:00, 329.00it/s]
100%|██████████| 72/72 [00:00<00:00, 852.56it/s]


save model
Epoch 2. accuracy: 0.7539404553415061


100%|██████████| 167/167 [00:00<00:00, 339.02it/s]
100%|██████████| 72/72 [00:00<00:00, 925.55it/s]


save model
Epoch 3. accuracy: 0.7622591943957968


100%|██████████| 167/167 [00:00<00:00, 329.26it/s]
100%|██████████| 72/72 [00:00<00:00, 759.92it/s]


save model
Epoch 4. accuracy: 0.7675131348511384


100%|██████████| 167/167 [00:00<00:00, 332.24it/s]
100%|██████████| 72/72 [00:00<00:00, 914.94it/s]


save model
Epoch 5. accuracy: 0.7727670753064798


100%|██████████| 167/167 [00:00<00:00, 332.95it/s]
100%|██████████| 72/72 [00:00<00:00, 917.79it/s]


Epoch 6. accuracy: 0.7714535901926445


100%|██████████| 167/167 [00:00<00:00, 319.55it/s]
100%|██████████| 72/72 [00:00<00:00, 805.74it/s]


save model
Epoch 7. accuracy: 0.7753940455341506


100%|██████████| 167/167 [00:00<00:00, 336.24it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 8. accuracy: 0.7740805604203153


100%|██████████| 167/167 [00:00<00:00, 309.76it/s]
100%|██████████| 72/72 [00:00<00:00, 694.67it/s]


save model
Epoch 9. accuracy: 0.7806479859894921


100%|██████████| 167/167 [00:00<00:00, 317.26it/s]
100%|██████████| 72/72 [00:00<00:00, 842.61it/s]


save model
Epoch 10. accuracy: 0.7850262697022767


100%|██████████| 167/167 [00:00<00:00, 332.18it/s]
100%|██████████| 72/72 [00:00<00:00, 849.32it/s]


Epoch 11. accuracy: 0.7845884413309983


100%|██████████| 167/167 [00:00<00:00, 312.66it/s]
100%|██████████| 72/72 [00:00<00:00, 859.43it/s]


Epoch 12. accuracy: 0.7793345008756567


100%|██████████| 167/167 [00:00<00:00, 338.28it/s]
100%|██████████| 72/72 [00:00<00:00, 859.44it/s]


Epoch 13. accuracy: 0.782399299474606


100%|██████████| 167/167 [00:00<00:00, 322.91it/s]
100%|██████████| 72/72 [00:00<00:00, 802.14it/s]


Epoch 14. accuracy: 0.7819614711033275


100%|██████████| 167/167 [00:00<00:00, 285.26it/s]
100%|██████████| 72/72 [00:00<00:00, 797.09it/s]


Epoch 15. accuracy: 0.7845884413309983


100%|██████████| 167/167 [00:00<00:00, 295.32it/s]
100%|██████████| 72/72 [00:00<00:00, 811.16it/s]


Epoch 16. accuracy: 0.7841506129597198


100%|██████████| 167/167 [00:00<00:00, 314.41it/s]
100%|██████████| 72/72 [00:00<00:00, 827.29it/s]


Epoch 17. accuracy: 0.7845884413309983


100%|██████████| 167/167 [00:00<00:00, 308.94it/s]
100%|██████████| 72/72 [00:00<00:00, 654.17it/s]


save model
Epoch 18. accuracy: 0.7863397548161121


100%|██████████| 167/167 [00:00<00:00, 290.43it/s]
100%|██████████| 72/72 [00:00<00:00, 752.01it/s]


Epoch 19. accuracy: 0.7832749562171629


100%|██████████| 167/167 [00:00<00:00, 289.43it/s]
100%|██████████| 72/72 [00:00<00:00, 694.16it/s]


Epoch 20. accuracy: 0.7802101576182137


100%|██████████| 167/167 [00:00<00:00, 280.01it/s]
100%|██████████| 72/72 [00:00<00:00, 820.38it/s]


Epoch 21. accuracy: 0.7819614711033275


100%|██████████| 167/167 [00:00<00:00, 332.90it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 22. accuracy: 0.7841506129597198


100%|██████████| 167/167 [00:00<00:00, 306.92it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


save model
Epoch 23. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 300.08it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


save model
Epoch 24. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 293.47it/s]
100%|██████████| 72/72 [00:00<00:00, 638.87it/s]


Epoch 25. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 323.53it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 26. accuracy: 0.7876532399299475


100%|██████████| 167/167 [00:00<00:00, 336.92it/s]
100%|██████████| 72/72 [00:00<00:00, 820.37it/s]


Epoch 27. accuracy: 0.7837127845884413


100%|██████████| 167/167 [00:00<00:00, 329.91it/s]
100%|██████████| 72/72 [00:00<00:00, 849.33it/s]


Epoch 28. accuracy: 0.7854640980735552


100%|██████████| 167/167 [00:00<00:00, 320.78it/s]
100%|██████████| 72/72 [00:00<00:00, 647.25it/s]


Epoch 29. accuracy: 0.7859019264448336


100%|██████████| 167/167 [00:00<00:00, 287.71it/s]
100%|██████████| 72/72 [00:00<00:00, 700.90it/s]


Epoch 30. accuracy: 0.7854640980735552


100%|██████████| 167/167 [00:00<00:00, 304.13it/s]
100%|██████████| 72/72 [00:00<00:00, 736.66it/s]


Epoch 31. accuracy: 0.7880910683012259


100%|██████████| 167/167 [00:00<00:00, 319.13it/s]
100%|██████████| 72/72 [00:00<00:00, 784.70it/s]


Epoch 32. accuracy: 0.7876532399299475


100%|██████████| 167/167 [00:00<00:00, 311.86it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 33. accuracy: 0.7837127845884413


100%|██████████| 167/167 [00:00<00:00, 313.83it/s]
100%|██████████| 72/72 [00:00<00:00, 707.78it/s]


Epoch 34. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 313.57it/s]
100%|██████████| 72/72 [00:00<00:00, 839.46it/s]


Epoch 35. accuracy: 0.7841506129597198


100%|██████████| 167/167 [00:00<00:00, 334.52it/s]
100%|██████████| 72/72 [00:00<00:00, 839.45it/s]


Epoch 36. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 332.24it/s]
100%|██████████| 72/72 [00:00<00:00, 790.69it/s]


Epoch 37. accuracy: 0.7867775831873906


100%|██████████| 167/167 [00:00<00:00, 323.39it/s]
100%|██████████| 72/72 [00:00<00:00, 700.90it/s]


Epoch 38. accuracy: 0.7845884413309983


100%|██████████| 167/167 [00:00<00:00, 303.83it/s]
100%|██████████| 72/72 [00:00<00:00, 596.64it/s]


Epoch 39. accuracy: 0.7880910683012259


100%|██████████| 167/167 [00:00<00:00, 277.70it/s]
100%|██████████| 72/72 [00:00<00:00, 788.34it/s]


save model
Epoch 40. accuracy: 0.7911558669001751


100%|██████████| 167/167 [00:00<00:00, 310.66it/s]
100%|██████████| 72/72 [00:00<00:00, 784.71it/s]


Epoch 41. accuracy: 0.7867775831873906


100%|██████████| 167/167 [00:00<00:00, 286.93it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


Epoch 42. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 287.71it/s]
100%|██████████| 72/72 [00:00<00:00, 649.33it/s]


Epoch 43. accuracy: 0.7876532399299475


100%|██████████| 167/167 [00:00<00:00, 282.85it/s]
100%|██████████| 72/72 [00:00<00:00, 687.55it/s]


Epoch 44. accuracy: 0.7911558669001751


100%|██████████| 167/167 [00:00<00:00, 265.11it/s]
100%|██████████| 72/72 [00:00<00:00, 606.66it/s]


Epoch 45. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 273.34it/s]
100%|██████████| 72/72 [00:00<00:00, 622.35it/s]


save model
Epoch 46. accuracy: 0.7915936952714536


100%|██████████| 167/167 [00:00<00:00, 296.19it/s]
100%|██████████| 72/72 [00:00<00:00, 711.16it/s]


Epoch 47. accuracy: 0.7841506129597198


100%|██████████| 167/167 [00:00<00:00, 303.24it/s]
100%|██████████| 72/72 [00:00<00:00, 802.14it/s]


Epoch 48. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 309.04it/s]
100%|██████████| 72/72 [00:00<00:00, 606.66it/s]


Epoch 49. accuracy: 0.7880910683012259


100%|██████████| 167/167 [00:00<00:00, 290.60it/s]
100%|██████████| 72/72 [00:00<00:00, 644.58it/s]


Epoch 50. accuracy: 0.7885288966725044


100%|██████████| 167/167 [00:00<00:00, 241.07it/s]
100%|██████████| 72/72 [00:00<00:00, 644.58it/s]


Epoch 51. accuracy: 0.787215411558669


100%|██████████| 167/167 [00:00<00:00, 274.24it/s]
100%|██████████| 72/72 [00:00<00:00, 656.30it/s]


Epoch 52. accuracy: 0.7889667250437828


100%|██████████| 167/167 [00:00<00:00, 290.92it/s]
100%|██████████| 72/72 [00:00<00:00, 784.70it/s]


Epoch 53. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 311.24it/s]
100%|██████████| 72/72 [00:00<00:00, 802.14it/s]


Epoch 54. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 285.96it/s]
100%|██████████| 72/72 [00:00<00:00, 768.01it/s]


Epoch 55. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 286.94it/s]
100%|██████████| 72/72 [00:00<00:00, 668.45it/s]


Epoch 56. accuracy: 0.7902802101576182


100%|██████████| 167/167 [00:00<00:00, 294.80it/s]
100%|██████████| 72/72 [00:00<00:00, 721.92it/s]


Epoch 57. accuracy: 0.7898423817863398


100%|██████████| 167/167 [00:00<00:00, 287.43it/s]
100%|██████████| 72/72 [00:00<00:00, 744.26it/s]


save model
Epoch 58. accuracy: 0.7924693520140105


100%|██████████| 167/167 [00:00<00:00, 286.93it/s]
100%|██████████| 72/72 [00:00<00:00, 638.88it/s]


Epoch 59. accuracy: 0.7915936952714536


100%|██████████| 167/167 [00:00<00:00, 274.50it/s]
100%|██████████| 72/72 [00:00<00:00, 616.83it/s]


save model
Epoch 60. accuracy: 0.7942206654991243


100%|██████████| 167/167 [00:00<00:00, 271.04it/s]
100%|██████████| 72/72 [00:00<00:00, 638.87it/s]


save model
Epoch 61. accuracy: 0.7955341506129597


100%|██████████| 167/167 [00:00<00:00, 284.99it/s]
100%|██████████| 72/72 [00:00<00:00, 668.45it/s]


save model
Epoch 62. accuracy: 0.797723292469352


100%|██████████| 167/167 [00:00<00:00, 283.06it/s]
100%|██████████| 72/72 [00:00<00:00, 687.55it/s]


Epoch 63. accuracy: 0.7911558669001751


100%|██████████| 167/167 [00:00<00:00, 281.90it/s]
100%|██████████| 72/72 [00:00<00:00, 674.70it/s]


Epoch 64. accuracy: 0.7911558669001751


100%|██████████| 167/167 [00:00<00:00, 287.43it/s]
100%|██████████| 72/72 [00:00<00:00, 674.70it/s]


Epoch 65. accuracy: 0.7968476357267951


100%|██████████| 167/167 [00:00<00:00, 292.96it/s]
100%|██████████| 72/72 [00:00<00:00, 656.30it/s]


Epoch 66. accuracy: 0.7959719789842382


100%|██████████| 167/167 [00:00<00:00, 259.60it/s]
100%|██████████| 72/72 [00:00<00:00, 584.20it/s]


Epoch 67. accuracy: 0.7955341506129597


100%|██████████| 167/167 [00:00<00:00, 274.95it/s]
100%|██████████| 72/72 [00:00<00:00, 776.27it/s]


Epoch 68. accuracy: 0.7942206654991243


100%|██████████| 167/167 [00:00<00:00, 301.95it/s]
100%|██████████| 72/72 [00:00<00:00, 776.26it/s]


Epoch 69. accuracy: 0.792907180385289


100%|██████████| 167/167 [00:00<00:00, 309.51it/s]
100%|██████████| 72/72 [00:00<00:00, 665.00it/s]


Epoch 70. accuracy: 0.7955341506129597


100%|██████████| 167/167 [00:00<00:00, 273.16it/s]
100%|██████████| 72/72 [00:00<00:00, 820.37it/s]


save model
Epoch 71. accuracy: 0.8021015761821366


100%|██████████| 167/167 [00:00<00:00, 303.58it/s]
100%|██████████| 72/72 [00:00<00:00, 802.14it/s]


Epoch 72. accuracy: 0.797723292469352


100%|██████████| 167/167 [00:00<00:00, 311.13it/s]
100%|██████████| 72/72 [00:00<00:00, 755.45it/s]


Epoch 73. accuracy: 0.797723292469352


100%|██████████| 167/167 [00:00<00:00, 268.44it/s]
100%|██████████| 72/72 [00:00<00:00, 656.30it/s]


Epoch 74. accuracy: 0.7994746059544658


100%|██████████| 167/167 [00:00<00:00, 274.14it/s]
100%|██████████| 72/72 [00:00<00:00, 601.60it/s]


Epoch 75. accuracy: 0.7990367775831874


100%|██████████| 167/167 [00:00<00:00, 277.89it/s]
100%|██████████| 72/72 [00:00<00:00, 707.77it/s]


Epoch 76. accuracy: 0.7972854640980735


100%|██████████| 167/167 [00:00<00:00, 297.07it/s]
100%|██████████| 72/72 [00:00<00:00, 687.55it/s]


Epoch 77. accuracy: 0.8012259194395797


100%|██████████| 167/167 [00:00<00:00, 260.83it/s]
100%|██████████| 72/72 [00:00<00:00, 752.01it/s]


Epoch 78. accuracy: 0.7999124343257443


100%|██████████| 167/167 [00:00<00:00, 303.86it/s]
100%|██████████| 72/72 [00:00<00:00, 784.71it/s]


Epoch 79. accuracy: 0.7964098073555166


100%|██████████| 167/167 [00:00<00:00, 309.11it/s]
100%|██████████| 72/72 [00:00<00:00, 784.71it/s]


Epoch 80. accuracy: 0.7946584938704028


100%|██████████| 167/167 [00:00<00:00, 279.75it/s]
100%|██████████| 72/72 [00:00<00:00, 644.58it/s]


Epoch 81. accuracy: 0.792031523642732


100%|██████████| 167/167 [00:00<00:00, 273.42it/s]
100%|██████████| 72/72 [00:00<00:00, 662.32it/s]


Epoch 82. accuracy: 0.7981611208406305


100%|██████████| 167/167 [00:00<00:00, 290.20it/s]
100%|██████████| 72/72 [00:00<00:00, 690.52it/s]


Epoch 83. accuracy: 0.8021015761821366


100%|██████████| 167/167 [00:00<00:00, 259.39it/s]
100%|██████████| 72/72 [00:00<00:00, 662.32it/s]


Epoch 84. accuracy: 0.7990367775831874


100%|██████████| 167/167 [00:00<00:00, 262.25it/s]
100%|██████████| 72/72 [00:00<00:00, 707.77it/s]


Epoch 85. accuracy: 0.7999124343257443


100%|██████████| 167/167 [00:00<00:00, 285.20it/s]
100%|██████████| 72/72 [00:00<00:00, 714.79it/s]


Epoch 86. accuracy: 0.7968476357267951


100%|██████████| 167/167 [00:00<00:00, 305.31it/s]
100%|██████████| 72/72 [00:00<00:00, 744.25it/s]


Epoch 87. accuracy: 0.7999124343257443


100%|██████████| 167/167 [00:00<00:00, 267.67it/s]
100%|██████████| 72/72 [00:00<00:00, 638.87it/s]


Epoch 88. accuracy: 0.7981611208406305


100%|██████████| 167/167 [00:00<00:00, 278.35it/s]
100%|██████████| 72/72 [00:00<00:00, 611.80it/s]


Epoch 89. accuracy: 0.8012259194395797


100%|██████████| 167/167 [00:00<00:00, 273.61it/s]
100%|██████████| 72/72 [00:00<00:00, 608.86it/s]


Epoch 90. accuracy: 0.797723292469352


100%|██████████| 167/167 [00:00<00:00, 288.92it/s]
100%|██████████| 72/72 [00:00<00:00, 650.39it/s]


Epoch 91. accuracy: 0.7955341506129597


100%|██████████| 167/167 [00:00<00:00, 279.90it/s]
100%|██████████| 72/72 [00:00<00:00, 674.70it/s]


Epoch 92. accuracy: 0.8012259194395797


100%|██████████| 167/167 [00:00<00:00, 285.49it/s]
100%|██████████| 72/72 [00:00<00:00, 681.07it/s]


Epoch 93. accuracy: 0.7985989492119089


100%|██████████| 167/167 [00:00<00:00, 295.44it/s]
100%|██████████| 72/72 [00:00<00:00, 668.45it/s]


Epoch 94. accuracy: 0.8016637478108581


100%|██████████| 167/167 [00:00<00:00, 271.96it/s]
100%|██████████| 72/72 [00:00<00:00, 656.30it/s]


Epoch 95. accuracy: 0.8021015761821366


100%|██████████| 167/167 [00:00<00:00, 281.63it/s]
100%|██████████| 72/72 [00:00<00:00, 662.32it/s]


save model
Epoch 96. accuracy: 0.8025394045534151


100%|██████████| 167/167 [00:00<00:00, 284.90it/s]
100%|██████████| 72/72 [00:00<00:00, 617.03it/s]


Epoch 97. accuracy: 0.8003502626970228


100%|██████████| 167/167 [00:00<00:00, 273.79it/s]
100%|██████████| 72/72 [00:00<00:00, 622.35it/s]


Epoch 98. accuracy: 0.7907180385288967


100%|██████████| 167/167 [00:00<00:00, 271.29it/s]
100%|██████████| 72/72 [00:00<00:00, 638.88it/s]

Epoch 99. accuracy: 0.7946584938704028





# Inference

In [11]:
# Predict the test set with the best checkpoint saved during training.
results_predict = []

# map_location lets a checkpoint saved on the GPU load on whatever `device` is now.
model.load_state_dict(torch.load('improvement/improvement1/best+LSTM.pt', map_location=device))
model.eval()
with torch.no_grad():
    for x in tqdm(test_dataloader):
        x = x.to(device)
        hat_y = model(x)
        hat_y = hat_y.squeeze(-1)
        # Threshold the probabilities at 0.5 to get binary class predictions.
        predictions = (hat_y > 0.5).int()
        results_predict.append(predictions.cpu())

# Flatten the per-batch tensors into one list, in test-file order (the test
# loader does not shuffle).
results_predict = torch.concat(results_predict).tolist()

100%|██████████| 102/102 [00:00<00:00, 538.28it/s]


# Export the results

In [13]:
# Pair each test id with its prediction and write the result file.
# (`test_ids` replaces the old name `id`, which shadowed the builtin.)
test_ids = pd.read_csv('nlp-getting-started/test.csv')['id']
assert len(test_ids) == len(results_predict), 'number of predictions does not match number of test ids'
prediction = pd.DataFrame({'id': test_ids.values, 'target': results_predict})
# Create the output directory if needed so to_csv cannot fail on a missing folder.
os.makedirs('prediction_result', exist_ok=True)
prediction.to_csv('prediction_result/prediction_result_LSTM.csv',index=False)