## Import Libraries

In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader,random_split
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'{device} is running..')

from transformers import BertForSequenceClassification,BertTokenizer,AdamW
from sklearn.metrics import f1_score

import gc,os,random

import wandb
wandb.login()

import re,string
import warnings
warnings.filterwarnings('ignore')

from tqdm import tqdm
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Seed every random number generator the notebook touches so that runs are
# repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
# Seeds all visible GPUs, not just the current one; ignored when CUDA is absent.
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# benchmark=True lets cuDNN auto-tune and pick different kernels from run to
# run, which defeats the deterministic setting above, so it must stay off.
torch.backends.cudnn.benchmark = False

In [None]:
# wandb sweep definition: random search that minimizes the logged 'valid_loss'.
# Only hyperparameters that the training code actually reads from wandb.config
# are listed (batch_size, lr, eps, epochs). The former 'learning_rate' entry
# duplicated 'lr' and was never consumed, so it only added a meaningless
# dimension to the search space.
sweep_config = {
    'name': 'Bert with CE',
    'method': 'random',
    'metric': {
        'name': 'valid_loss',
        'goal': 'minimize',
    },
    'parameters': {
        # Fixed number of epochs for every trial.
        'epochs': {
            'values': [4],
        },
        'batch_size': {
            'values': [16, 32],
        },
        # Optimizer learning rate, sampled between min and max.
        'lr': {
            'min': 1e-7,
            'max': 1e-6,
        },
        # Optimizer epsilon, sampled between min and max.
        'eps': {
            'min': 1e-9,
            'max': 1e-8,
        },
    },
}

## Data

In [None]:
# Read both CSV splits and stack them (train rows first) so the text cleaning
# can be applied to everything in one pass. train_len records where the
# training rows end so the stacked frame can be cut apart again later.
train_data = pd.read_csv('./Data/train.csv')
train_len = len(train_data)
test_data = pd.read_csv('./Data/test.csv')

all_data = pd.concat((train_data, test_data))

In [None]:
# English stopwords from NLTK, held in a set for fast membership checks when
# tokens are filtered during cleaning.
stop = set(stopwords.words('english'))

In [None]:
def remove_tag(text):
    """Return ``text`` with every ``@``-prefixed run of non-whitespace removed."""
    mention_pattern = re.compile(r'@\S+')
    return mention_pattern.sub('', text)

def remove_URL(text):
    """Return ``text`` with ``http(s)://...`` and ``www....`` tokens removed."""
    link_pattern = re.compile(r'https?://\S+|www\.\S+')
    return link_pattern.sub('', text)

def remove_html(text):
    """Return ``text`` without ``<...>`` tags.

    Note that the pattern also deletes any non-empty ``(...)`` group,
    parentheses included, not just markup.
    """
    markup_pattern = re.compile(r'<[^>]+>|\([^)]+\)')
    return markup_pattern.sub('', text)

def remove_punct(text):
    """Return ``text`` with every character in ``string.punctuation`` deleted."""
    # A translation table whose third argument lists the characters to drop.
    drop_table = str.maketrans('', '', string.punctuation)
    return text.translate(drop_table)

In [None]:
# Build the 'cleaned' column from 'text' in one chained pass:
# strip @mentions -> strip URLs -> strip punctuation -> lowercase -> tokenize
# -> drop English stopwords and re-join the remaining tokens with spaces.
all_data['cleaned'] = (
    all_data['text']
    .apply(remove_tag)
    .apply(remove_URL)
    .apply(remove_punct)
    .apply(str.lower)
    .apply(word_tokenize)
    .apply(lambda tokens: ' '.join([tok for tok in tokens if tok not in stop]))
)

In [None]:
# Cut the stacked frame apart again: the first train_len rows are the training
# split, everything after them is the test split.
train_data = all_data[:train_len]
test_data = all_data[train_len:]

In [None]:
class TweetsDataset(Dataset):
    """Map-style dataset that tokenizes the ``cleaned`` column of a DataFrame on demand.

    Args:
        df: DataFrame with a ``cleaned`` text column and, when ``label`` is
            truthy, a ``target`` column.
        tokenizer: Object exposing ``encode_plus``; it is called with the
            keyword arguments shown in ``__getitem__``.
        label: When truthy, each sample also carries a ``'labels'`` tensor.
        max_length: Length every sequence is padded/truncated to (default 84).
    """

    def __init__(self, df, tokenizer, label, max_length=84):
        self.df = df
        self.tokenizer = tokenizer
        self.label = label
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and optionally ``labels``."""
        # Positional lookup: samplers hand out 0..len-1, which does not have to
        # coincide with the DataFrame's index labels (e.g. after filtering or
        # slicing), so label-based .loc would raise or pick the wrong row.
        text = self.df['cleaned'].iloc[idx]

        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
            return_attention_mask=True,
        )

        # Drop only the leading batch axis (assumes the tokenizer returns
        # tensors shaped (1, max_length) for a single text); a bare squeeze()
        # would also collapse the sequence axis if max_length were 1.
        sample = {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
        }
        if self.label:
            sample['labels'] = torch.tensor(self.df['target'].iloc[idx], dtype=torch.long)
        return sample

In [None]:
# Checkpoint identifier handed to BertTokenizer.from_pretrained below.
model_name = 'bert-base-uncased'
# Module-level tokenizer; TweetsLoader reads this global when it builds its datasets.
tokenizer = BertTokenizer.from_pretrained(model_name)

In [None]:
def TweetsLoader(train_data,test_data,batch_size,split_seed=42):
    """Build train/validation/test DataLoaders from the two DataFrames.

    The labelled frame is split 80/20 into training and validation subsets.

    Args:
        train_data: Labelled DataFrame (wrapped in ``TweetsDataset`` with labels).
        test_data: Unlabelled DataFrame (wrapped in ``TweetsDataset`` without labels).
        batch_size: Batch size for the train and validation loaders.
        split_seed: Seed for the train/validation split. A dedicated generator
            is used so the split is the same on every call; with the global
            RNG each sweep trial would validate on different rows and the
            ``valid_loss`` values would not be comparable across trials.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``.
    """
    labelled_dataset = TweetsDataset(train_data,tokenizer,True)
    test_dataset = TweetsDataset(test_data,tokenizer,False)

    train_size = int(len(train_data) * 0.8)
    valid_size = len(train_data) - train_size

    split_rng = torch.Generator().manual_seed(split_seed)
    train_dataset,valid_dataset = random_split(
        labelled_dataset,[train_size,valid_size],generator=split_rng
    )

    # Only the training loader shuffles; validation/test order stays fixed.
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True)
    valid_loader = DataLoader(valid_dataset,batch_size=batch_size,shuffle=False,pin_memory=True)
    test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False)
    return train_loader,valid_loader,test_loader

## Model

In [None]:
class TweetsModel(nn.Module):
    """Thin wrapper around ``BertForSequenceClassification`` that returns raw logits.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return only the ``logits`` of its output."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

## Train

In [None]:
def train_valid(model,train_loader,valid_loader,criterion,optimizer,metric,epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Running training loss, validation loss and validation score are logged to
    wandb and printed.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns logits.
        train_loader: Iterable of batch dicts with 'input_ids', 'attention_mask', 'labels'.
        valid_loader: Same batch format as ``train_loader``; used without gradients.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over the model's parameters.
        metric: Callable ``metric(y_true, y_pred)`` used for the validation
            score; falls back to ``f1_score`` if it is not callable.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None.
    """
    wandb.watch(model,criterion,log='all',log_freq=10)

    score_fn = metric if callable(metric) else f1_score
    # Start from +inf so the first epoch always counts as an improvement,
    # instead of comparing against an arbitrary loss of 1.
    best_valid_loss = float('inf')

    for epoch in range(epochs):
        gc.collect()

        # Validation below switches the model to eval mode; switch back so
        # dropout is active again for every training epoch, not just the first.
        model.train()
        pbar = tqdm(train_loader,desc='Training..')

        train_loss = 0.0
        train_step = 0
        for batch in pbar:
            train_step += 1
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids,attention_mask)
            loss = criterion(logits,labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Running mean of the loss over the batches seen so far this epoch.
            running_loss = train_loss/train_step
            pbar.set_postfix({'train_loss':running_loss})
            wandb.log({'train_loss':running_loss})
        print(f'Epoch [{epoch+1}/{epochs}] Train_loss: {train_loss/max(train_step,1)}')
        pbar.close()

        model.eval()
        with torch.no_grad():
            pbar = tqdm(valid_loader,desc='Validating..')

            y_pred,y_true = [],[]

            valid_loss = 0.0
            valid_step = 0
            for batch in pbar:
                valid_step += 1

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                logits = model(input_ids,attention_mask)
                predictions = torch.argmax(logits,dim=1)

                loss = criterion(logits,labels)
                valid_loss += loss.item()

                y_pred.extend(predictions.cpu().numpy())
                y_true.extend(labels.cpu().numpy())
            pbar.close()

        # max(...) guards against an empty validation loader.
        valid_loss /= max(valid_step,1)
        f1 = score_fn(y_true,y_pred)

        if valid_loss < best_valid_loss:
            print('model improved!')
            best_valid_loss = valid_loss
        else:
            print('model "not" improved..')

        # One call so both validation values land on the same wandb step.
        wandb.log({'valid_loss':valid_loss,'valid_score':f1})

        print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
        print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')
        print('='*100)
    print('Train/Valid completed')

    gc.collect()

## Sweep

In [None]:
def run_sweep(config=None):
    """Run a single sweep trial inside its own wandb run.

    Hyperparameters (``batch_size``, ``lr``, ``eps``, ``epochs``) are read
    from ``wandb.config``; data loaders, model and optimizer are built from
    them, handed to ``train_valid``, and released afterwards.

    Args:
        config: Optional configuration forwarded to ``wandb.init``.
    """
    with wandb.init(config=config) as run:
        # Start the trial with as much free memory as possible.
        torch.cuda.empty_cache()
        gc.collect()

        run.name = 'Bert_base_raw'
        hparams = wandb.config

        metric = f1_score
        criterion = nn.CrossEntropyLoss()
        # Loaders are built before the model, as both may consume random state.
        train_loader, valid_loader, test_loader = TweetsLoader(
            train_data, test_data, hparams.batch_size
        )
        model = TweetsModel('bert-base-uncased').to(device)
        optimizer = AdamW(
            model.parameters(),
            lr=hparams.lr,
            eps=hparams.eps,
            no_deprecation_warning=True,
        )

        train_valid(
            model, train_loader, valid_loader,
            criterion, optimizer, metric, hparams.epochs,
        )

        # Drop the references first so the cache flush below can reclaim them.
        del criterion, train_loader, valid_loader, model, optimizer, metric
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
# Register the sweep described by sweep_config under the given project/entity
# and keep the returned id.
sweep_id = wandb.sweep(sweep_config,project='sweep_bert_base',entity='chanmuzi')
# Start an agent that calls run_sweep for each trial; count=15 presumably caps
# the number of trials — confirm against the wandb docs.
wandb.agent(sweep_id,run_sweep,count=15)