<a href="https://colab.research.google.com/github/jwengr/dacon/blob/main/%EC%86%8C%EC%84%A4%20%EC%9E%91%EA%B0%80%20%EB%B6%84%EB%A5%98%20AI%20%EA%B2%BD%EC%A7%84%EB%8C%80%ED%9A%8C/DL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import re
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.externals import joblib
import matplotlib.pyplot as plt
% matplotlib inline

In [None]:
!pip install torchcontrib

In [3]:
import torch.nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from torchcontrib.optim import SWA

In [None]:
import torch
if torch.cuda.is_available():     
    device = torch.device("cuda:0")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

In [None]:
!pip install transformers

In [6]:
from transformers import AdamW, XLNetTokenizer, XLNetModel, XLNetConfig, XLNetForSequenceClassification

In [7]:
defaultpath = 'drive/My Drive/dacon/sosul/dataset'

기본전처리

In [23]:
train_df = pd.read_csv(defaultpath+'/train.csv',encoding='utf-8')
train_df = train_df[train_df['text'].str.contains('\* \*')==False]
train_df['sentencelen'] = train_df['text'].apply(lambda x: len(x.split('.')))
train_df['charlen'] = train_df['text'].apply(lambda x: len(x))
train_df['c/s'] = train_df['charlen']/(train_df['sentencelen']+1)  ## 0으로 나뉘는것을 방지
train_df['upperlen'] = train_df['text'].apply(lambda x: len(re.findall('[A-Z]',x)))
train_df['u/s'] = train_df['upperlen']/(train_df['sentencelen']+1)  ## 0으로 나뉘는것을 방지
train_df['u/s'] = train_df['upperlen']/(train_df['charlen']+1)  ## 0으로 나뉘는것을 방지

train_df_fr = train_df[train_df['text'].str.contains('[à|ä|ö|î|ù|â|Œ|ç|ê|ü|ñ|ô|Æ|œ|ë|æ|é|Ê|è|ì]')].copy()

train, valid = train_test_split(train_df,test_size=0.2, random_state=2021, stratify=train_df['author']) 
tfidfv = TfidfVectorizer(token_pattern="[a-zA-Z]+|\W",max_features=5000,lowercase=True,dtype=np.float32).fit(train_df['text'])

tfidf + mlp

In [48]:
class TfidfTrainDataset(Dataset):
    def __init__(self,tfidfv=None,df=None):
        self.tfidfv = tfidfv
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        df = self.df.iloc[idx,:]
        enc = self.tfidfv.transform([df['text']]).toarray().astype(np.float32)
        enc = torch.from_numpy(enc[0]).tolist()
        df = df.drop(['index','text','sentencelen','charlen','upperlen'])
        add = torch.from_numpy(df.drop('author').values.astype(np.float32)).tolist()
        input_ids = enc+add
        labels = df['author'].astype(np.int32).tolist()
        
        return input_ids,labels

In [49]:
tfidf_train_dataset = TfidfTrainDataset(tfidfv,train)
tfidf_valid_dataset = TfidfTrainDataset(tfidfv,valid)
def collate_fn(batch):
    return list(zip(*batch))
tfidf_train_dataloader = DataLoader(tfidf_train_dataset, batch_size=4, shuffle=True, num_workers=2,collate_fn=collate_fn)
tfidf_valid_dataloader = DataLoader(tfidf_valid_dataset, batch_size=4, shuffle=True, num_workers=2,collate_fn=collate_fn)

In [56]:
class TfidfMLPModel(torch.nn.Module):
    def __init__(self):
        super(TfidfMLPModel, self).__init__()
        self.linear1 = torch.nn.Linear(5002,512)
        self.linear2 = torch.nn.Linear(512,64)
        self.linear3 = torch.nn.Linear(64,5)
        self.drop1 = torch.nn.Dropout()
        self.drop2 = torch.nn.Dropout()

        torch.nn.init.xavier_normal_(self.linear1.weight)
        torch.nn.init.xavier_normal_(self.linear2.weight)
        torch.nn.init.xavier_normal_(self.linear3.weight)
        
    def forward(self, input_ids):
        x = self.linear1(input_ids)
        x = self.drop1(x)
        x = self.linear2(x)
        x = self.drop2(x)
        logits = self.linear3(x)
        return logits

In [57]:
# function to save and load the model form a specific epoch
def save_model(model, save_path, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist,train_acc_hist,valid_acc_hist):

    model_to_save = model.module if hasattr(model, 'module') else model
    checkpoint = {'epochs': epochs, 
                    'lowest_eval_loss': lowest_eval_loss,
                    'state_dict': model_to_save.state_dict(),
                    'train_loss_hist': train_loss_hist,
                    'valid_loss_hist': valid_loss_hist,
                    'train_acc_hist' : train_acc_hist,
                    'valid_acc_hist' : valid_acc_hist
                }
    torch.save(checkpoint, save_path+'/MLP_e{0}_loss{1:04.4f}_acc{2:04.4f}.pth'.format(epochs,lowest_eval_loss,valid_acc_hist[-1]))
    print("Saving model at epoch {0} with validation loss of {1} vaildation acc of {2}".format(epochs,
                                                                        lowest_eval_loss,valid_acc_hist[-1]))
    return
  
def load_model(save_path):
    checkpoint = torch.load(save_path)
    model_state_dict = checkpoint['state_dict']
    model = TfidfMLPModel()
    model.load_state_dict(model_state_dict)    
    return model, checkpoint

In [58]:
model = TfidfMLPModel()

In [59]:
adamOptimizer = AdamW(model.parameters(),lr = 1e-5, eps = 1e-8, correct_bias=False)
optimizer = SWA(adamOptimizer, swa_start=4, swa_freq=3, swa_lr=1e-5)

In [74]:
def model_train(model, num_epochs,optimizer,
          train_dataloader, valid_dataloader,model_save_path,checkpoint,device="cpu"
          ):
    if checkpoint is None:
        start_epoch=0
        lowest_eval_loss = float('inf')
        train_loss_hist = []
        valid_loss_hist = []
        train_acc_hist = []
        valid_acc_hist = []
    else:
        start_epochs = checkpoint["epochs"]+1
        lowest_eval_loss = checkpoint["lowest_eval_loss"]
        train_loss_hist = checkpoint["train_loss_hist"]
        valid_loss_hist = checkpoint["valid_loss_hist"]
        train_acc_hist = checkpoint["train_acc_hist"]
        vaild_acc_hist = checkpoint["vaild_acc_hist"]

    model.to(device)
    for i in range(start_epoch,num_epochs):
        actual_epoch = start_epoch + i

        model.train()
        tr_acc = 0
        tr_loss = 0
        num_train_samples = 0
        train_bar = tqdm(train_dataloader,desc=f"Epoch {actual_epoch} Train ")
        for step, batch in enumerate(train_bar):
            b_input_ids, b_labels = torch.FloatTensor(batch[0]).to(device), torch.LongTensor(batch[1]).to(device)
            num_train_samples += b_labels.size(0) 

            optimizer.zero_grad()
            logits = model(input_ids=b_input_ids)
            loss = CrossEntropyLoss()(logits, b_labels)

            prediction = logits.data.max(1)[1]
            tr_acc += prediction.eq(b_labels.data).sum().item()
            tr_loss += loss.item()
            train_bar.set_postfix({'train_acc': tr_acc/num_train_samples,'train_loss':tr_loss/num_train_samples})

            loss.backward()
            optimizer.step()

        optimizer.swap_swa_sgd()
        train_loss_hist.append(tr_loss/num_train_samples)
        train_acc_hist.append(tr_acc/num_train_samples)
            
        model.eval()
        eval_loss = 0
        eval_acc = 0
        num_eval_samples = 0
        with torch.no_grad():
            valid_bar = tqdm(valid_dataloader,desc=f"Epoch {actual_epoch} Valid ")
            for batch in valid_bar:
                b_input_ids, b_labels = torch.FloatTensor(batch[0]).to(device), torch.LongTensor(batch[1]).to(device)

                logits = model(input_ids=b_input_ids)
                loss = CrossEntropyLoss()(logits, b_labels)
                prediction = logits.data.max(1)[1]

                eval_acc += prediction.eq(b_labels.data).sum().item()
                eval_loss += loss.item()
                num_eval_samples += b_labels.size(0)
                valid_bar.set_postfix({'valid_acc':eval_acc/num_eval_samples,'valid_loss':eval_loss/num_eval_samples})

            valid_loss_hist.append(eval_loss/num_eval_samples)
            valid_acc_hist.append(eval_acc/num_eval_samples)
            
        if valid_loss_hist[-1] < lowest_eval_loss:
            lowest_eval_loss = valid_loss_hist[-1]
            save_model(model, model_save_path, actual_epoch, lowest_eval_loss, train_loss_hist, valid_loss_hist,train_acc_hist,valid_acc_hist)
    return model

In [None]:
model = model_train(model=model, num_epochs = 100, 
            model_save_path=defaultpath+'/model', checkpoint=None,
            optimizer=optimizer, device=device,
           train_dataloader=tfidf_train_dataloader, valid_dataloader=tfidf_valid_dataloader)

HBox(children=(FloatProgress(value=0.0, description='Epoch 0 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 0 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 0 with validation loss of 0.3722827439105118 vaildation acc of 0.43621774414151543


HBox(children=(FloatProgress(value=0.0, description='Epoch 1 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 1 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 1 with validation loss of 0.359951422742129 vaildation acc of 0.45445427190662896


HBox(children=(FloatProgress(value=0.0, description='Epoch 2 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 2 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 2 with validation loss of 0.3559400081482344 vaildation acc of 0.48071487188839246


HBox(children=(FloatProgress(value=0.0, description='Epoch 3 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 3 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 3 with validation loss of 0.345298805987486 vaildation acc of 0.4853651864684964


HBox(children=(FloatProgress(value=0.0, description='Epoch 4 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 4 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 4 with validation loss of 0.3370674541429822 vaildation acc of 0.5235707121364093


HBox(children=(FloatProgress(value=0.0, description='Epoch 5 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 5 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 5 with validation loss of 0.32122587903300076 vaildation acc of 0.5484635725357891


HBox(children=(FloatProgress(value=0.0, description='Epoch 6 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 6 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 6 with validation loss of 0.3102338180411989 vaildation acc of 0.5742682593234248


HBox(children=(FloatProgress(value=0.0, description='Epoch 7 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 7 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 7 with validation loss of 0.292546746536126 vaildation acc of 0.598066928056898


HBox(children=(FloatProgress(value=0.0, description='Epoch 8 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 8 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 8 with validation loss of 0.28085110849393713 vaildation acc of 0.5942372572262241


HBox(children=(FloatProgress(value=0.0, description='Epoch 9 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 9 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 9 with validation loss of 0.2637292998686217 vaildation acc of 0.6339928877541716


HBox(children=(FloatProgress(value=0.0, description='Epoch 10 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 10 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 10 with validation loss of 0.2526717345176187 vaildation acc of 0.6600711224582839


HBox(children=(FloatProgress(value=0.0, description='Epoch 11 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 11 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 11 with validation loss of 0.2394160405242955 vaildation acc of 0.6776693717516185


HBox(children=(FloatProgress(value=0.0, description='Epoch 12 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 12 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 12 with validation loss of 0.2312811037940456 vaildation acc of 0.6895231147989422


HBox(children=(FloatProgress(value=0.0, description='Epoch 13 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 13 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 13 with validation loss of 0.22258839609236097 vaildation acc of 0.7073949120087535


HBox(children=(FloatProgress(value=0.0, description='Epoch 14 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 14 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 14 with validation loss of 0.21520678212252553 vaildation acc of 0.7157837147807058


HBox(children=(FloatProgress(value=0.0, description='Epoch 15 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 15 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 15 with validation loss of 0.20814982872584373 vaildation acc of 0.7237166043585301


HBox(children=(FloatProgress(value=0.0, description='Epoch 16 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 16 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 16 with validation loss of 0.20460471177387554 vaildation acc of 0.727546275189204


HBox(children=(FloatProgress(value=0.0, description='Epoch 17 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 17 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 17 with validation loss of 0.19908044150885013 vaildation acc of 0.7290963800492386


HBox(children=(FloatProgress(value=0.0, description='Epoch 18 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 18 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 18 with validation loss of 0.1954215765515706 vaildation acc of 0.7359350779611562


HBox(children=(FloatProgress(value=0.0, description='Epoch 19 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 19 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 19 with validation loss of 0.1917585916042714 vaildation acc of 0.7432296890672017


HBox(children=(FloatProgress(value=0.0, description='Epoch 20 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 20 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 20 with validation loss of 0.1884642913407665 vaildation acc of 0.7419531321236437


HBox(children=(FloatProgress(value=0.0, description='Epoch 21 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 21 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 21 with validation loss of 0.18754649505116336 vaildation acc of 0.7373028175435397


HBox(children=(FloatProgress(value=0.0, description='Epoch 22 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 22 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 22 with validation loss of 0.18399255911700663 vaildation acc of 0.7479711862861311


HBox(children=(FloatProgress(value=0.0, description='Epoch 23 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 23 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 23 with validation loss of 0.1813034122797709 vaildation acc of 0.7564511716969089


HBox(children=(FloatProgress(value=0.0, description='Epoch 24 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 24 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 24 with validation loss of 0.18063166444275192 vaildation acc of 0.7538068751709674


HBox(children=(FloatProgress(value=0.0, description='Epoch 25 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 25 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 25 with validation loss of 0.17883462900422556 vaildation acc of 0.7580924591957692


HBox(children=(FloatProgress(value=0.0, description='Epoch 26 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 26 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 26 with validation loss of 0.17653259243980882 vaildation acc of 0.7580012765569436


HBox(children=(FloatProgress(value=0.0, description='Epoch 27 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 27 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 27 with validation loss of 0.17495821547792417 vaildation acc of 0.7624692258593964


HBox(children=(FloatProgress(value=0.0, description='Epoch 28 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 28 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 28 with validation loss of 0.1740069727381426 vaildation acc of 0.7610103036381873


HBox(children=(FloatProgress(value=0.0, description='Epoch 29 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 29 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 29 with validation loss of 0.1724891170307014 vaildation acc of 0.7631075043311754


HBox(children=(FloatProgress(value=0.0, description='Epoch 30 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 30 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 31 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 31 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 32 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 32 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 32 with validation loss of 0.17014784803551178 vaildation acc of 0.7648399744688611


HBox(children=(FloatProgress(value=0.0, description='Epoch 33 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 33 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 33 with validation loss of 0.16959823528483028 vaildation acc of 0.7641105133582566


HBox(children=(FloatProgress(value=0.0, description='Epoch 34 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 34 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 35 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 35 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 35 with validation loss of 0.16857975245657184 vaildation acc of 0.7701285675207441


HBox(children=(FloatProgress(value=0.0, description='Epoch 36 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 36 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 37 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 37 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 37 with validation loss of 0.1676421833141033 vaildation acc of 0.7683049147442327


HBox(children=(FloatProgress(value=0.0, description='Epoch 38 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 38 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 38 with validation loss of 0.1670577538487943 vaildation acc of 0.7710403939089997


HBox(children=(FloatProgress(value=0.0, description='Epoch 39 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 39 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 40 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 40 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 41 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 41 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 41 with validation loss of 0.16663784464467357 vaildation acc of 0.7718610376584298


HBox(children=(FloatProgress(value=0.0, description='Epoch 42 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 42 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 43 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 43 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 44 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 44 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 44 with validation loss of 0.16636274451775618 vaildation acc of 0.7728640466855111


HBox(children=(FloatProgress(value=0.0, description='Epoch 45 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 45 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 45 with validation loss of 0.16627850398402436 vaildation acc of 0.7725904987690344


HBox(children=(FloatProgress(value=0.0, description='Epoch 46 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 46 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 47 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 47 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 47 with validation loss of 0.16613933859420005 vaildation acc of 0.7734111425184644


HBox(children=(FloatProgress(value=0.0, description='Epoch 48 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 48 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 49 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 49 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 50 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 50 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 51 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 51 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 52 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 52 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 53 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 53 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 54 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 54 Valid ', max=2742.0, style=ProgressStyle(descrip…


Saving model at epoch 54 with validation loss of 0.16607663381307722 vaildation acc of 0.7746876994620224


HBox(children=(FloatProgress(value=0.0, description='Epoch 55 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 55 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 56 Train ', max=10967.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 56 Valid ', max=2742.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 57 Train ', max=10967.0, style=ProgressStyle(descri…

tfidf + xlnet

In [13]:
config = XLNetConfig(
    vocab_size= 1002,
    d_model= 32,
    n_layer= 8,
    n_head=16,
    d_inner=128
)
class XLNetForMultiLabelSequenceClassification(torch.nn.Module):
  
    def __init__(self,config):
        super(XLNetForMultiLabelSequenceClassification, self).__init__()
        self.xlnet = XLNetModel(config)
        self.linear = torch.nn.Linear(32, 5)

        torch.nn.init.xavier_normal_(self.linear.weight)

    def forward(self, input_ids, token_type_ids=None,
                attention_mask=None,labels=None):

        last_hidden_state = self.xlnet(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids
                                    )
        mean_last_hidden_state = torch.mean(last_hidden_state[0],1)
        logits = self.linear(mean_last_hidden_state)
        return logits

In [14]:
# function to save and load the model form a specific epoch
def save_model(model, save_path, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist,train_acc_hist,valid_acc_hist):

    model_to_save = model.module if hasattr(model, 'module') else model
    checkpoint = {'epochs': epochs, 
                    'lowest_eval_loss': lowest_eval_loss,
                    'state_dict': model_to_save.state_dict(),
                    'train_loss_hist': train_loss_hist,
                    'valid_loss_hist': valid_loss_hist,
                    'train_acc_hist' : train_acc_hist,
                    'valid_acc_hist' : valid_acc_hist
                }
    torch.save(checkpoint, save_path+'/e{0}_loss{1:04.4f}_acc{2:04.4f}.pth'.format(epochs,lowest_eval_loss,valid_acc_hist[-1]))
    print("Saving model at epoch {0} with validation loss of {1} vaildation acc of {2}".format(epochs,
                                                                        lowest_eval_loss,valid_acc_hist[-1]))
    return
  
def load_model(save_path):
    checkpoint = torch.load(save_path)
    model_state_dict = checkpoint['state_dict']
    model = T5ForMultiLabelSequenceClassification(config=config)
    model.load_state_dict(model_state_dict)    
    return model, checkpoint

In [15]:
model = XLNetForMultiLabelSequenceClassification(config=config)

Stochastic Weight Averaging

In [16]:
adamOptimizer = AdamW(model.parameters(),lr = 1e-5, eps = 1e-8, correct_bias=False)
optimizer = SWA(adamOptimizer, swa_start=4, swa_freq=3, swa_lr=1e-5)

In [17]:
def model_train(model, num_epochs,optimizer,
          train_dataloader, valid_dataloader,model_save_path,checkpoint,device="cpu"
          ):
    if checkpoint is None:
        start_epoch=0
        lowest_eval_loss = float('inf')
        train_loss_hist = []
        valid_loss_hist = []
        train_acc_hist = []
        valid_acc_hist = []
    else:
        start_epochs = checkpoint["epochs"]+1
        lowest_eval_loss = checkpoint["lowest_eval_loss"]
        train_loss_hist = checkpoint["train_loss_hist"]
        valid_loss_hist = checkpoint["valid_loss_hist"]
        train_acc_hist = checkpoint["train_acc_hist"]
        vaild_acc_hist = checkpoint["vaild_acc_hist"]

    model.to(device)
    for i in range(start_epoch,num_epochs):
        actual_epoch = start_epoch + i

        model.train()
        tr_acc = 0
        tr_loss = 0
        num_train_samples = 0
        train_bar = tqdm(train_dataloader,desc=f"Epoch {actual_epoch} Train ")
        for step, batch in enumerate(train_bar):
            batch = (torch.LongTensor(b).to(device) for b in batch)
            b_input_ids, b_input_mask, b_labels = batch
            num_train_samples += b_labels.size(0) 

            optimizer.zero_grad()
            logits = model(input_ids=b_input_ids, attention_mask=b_input_mask, labels=b_labels)
            loss = CrossEntropyLoss()(logits, b_labels)

            prediction = logits.data.max(1)[1]
            tr_acc += prediction.eq(b_labels.data).sum().item()
            tr_loss += loss.item()
            train_bar.set_postfix({'train_acc': tr_acc/num_train_samples,'train_loss':tr_loss/num_train_samples})

            loss.backward()
            optimizer.step()

        optimizer.swap_swa_sgd()
        train_loss_hist.append(tr_loss/num_train_samples)
        train_acc_hist.append(tr_acc/num_train_samples)
            
        model.eval()
        eval_loss = 0
        eval_acc = 0
        num_eval_samples = 0
        with torch.no_grad():
            valid_bar = tqdm(valid_dataloader,desc=f"Epoch {actual_epoch} Valid ")
            for batch in valid_bar:
                batch = (torch.LongTensor(b).to(device) for b in batch)
                b_input_ids, b_input_mask, b_labels = batch

                logits = model(input_ids=b_input_ids, attention_mask=b_input_mask, labels=b_labels)
                loss = CrossEntropyLoss()(logits, b_labels)
                prediction = logits.data.max(1)[1]

                eval_acc += prediction.eq(b_labels.data).sum().item()
                eval_loss += loss.item()
                num_eval_samples += b_labels.size(0)
                valid_bar.set_postfix({'valid_acc':eval_acc/num_eval_samples,'valid_loss':eval_loss/num_eval_samples})

            valid_loss_hist.append(eval_loss/num_eval_samples)
            valid_acc_hist.append(eval_acc/num_eval_samples)
            
        if valid_loss_hist[-1] < lowest_eval_loss:
            lowest_eval_loss = valid_loss_hist[-1]
            save_model(model, model_save_path, actual_epoch, lowest_eval_loss, train_loss_hist, valid_loss_hist,train_acc_hist,valid_acc_hist)
    return model

In [18]:
model = model_train(model=model, num_epochs = 20, 
            model_save_path=defaultpath+'/model', checkpoint=None,
            optimizer=optimizer, device=device,
           train_dataloader=tfidf_train_dataloader, valid_dataloader=tfidf_valid_dataloader)

HBox(children=(FloatProgress(value=0.0, description='Epoch 0 Train ', max=10967.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Epoch 0 Valid ', max=2742.0, style=ProgressStyle(descript…


Saving model at epoch 0 with validation loss of 0.3925555369755493 vaildation acc of 0.27464210814260964


HBox(children=(FloatProgress(value=0.0, description='Epoch 1 Train ', max=10967.0, style=ProgressStyle(descrip…

KeyboardInterrupt: ignored