In [31]:
import pickle
from dataclasses import dataclass, field
from itertools import islice
from typing import Literal, Callable

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import tqdm
import wandb
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModel

In [32]:
class MyModel(nn.Module):
    """Binary classifier: an encoder followed by a head and a sigmoid.

    The encoder is called with the batch unpacked as keyword arguments and
    must return an object exposing ``last_hidden_state``. The hidden state at
    sequence position 0 is fed to the head, and the head's output is squashed
    to the (0, 1) range.
    """

    def __init__(self, embedding_model, head_model):
        super().__init__()
        self.embedding_model = embedding_model
        self.head_model = head_model
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the head's output for the batch ``x``.

        ``x`` has to be a mapping usable as ``**kwargs`` for the encoder.
        """
        encoder_output = self.embedding_model(**x)
        # Position 0 of every sequence serves as the sequence representation.
        first_position_state = encoder_output.last_hidden_state[:, 0, :]
        return self.sigmoid(self.head_model(first_position_state))

In [33]:
def train_step(model, dataloader):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``Config``.
    Each batch must be a dict with ``'tokens'`` and ``'labels'`` entries that
    both support ``.to(device)``.

    Returns the mean of the per-batch losses; the same value is logged to
    wandb as ``"train loss"`` when ``Config.use_wandb`` is set.
    """
    model.train()
    batch_losses = []
    for batch in tqdm.tqdm(dataloader):
        tokens = batch['tokens'].to(Config.device)
        targets = batch['labels'].to(Config.device)

        optimizer.zero_grad()
        predictions = model(tokens)
        loss = criterion(predictions, targets)
        batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()

        # Drop the per-batch references before asking CUDA to release cached memory.
        del batch
        del tokens
        del targets
        del predictions
        torch.cuda.empty_cache()

    epoch_loss = np.array(batch_losses).mean()
    if Config.use_wandb:
        wandb.log({"train loss": epoch_loss})
    return epoch_loss


@torch.inference_mode()
def valid_step(model, dataloader, should_be_neg, should_be_pos):
    """Evaluate ``model`` on ``dataloader`` and return ``(mean_loss, accuracy)``.

    Uses the notebook-level ``criterion`` and ``Config``. Model outputs are
    thresholded at 0.5 and compared with the labels.

    ``should_be_neg`` / ``should_be_pos`` are DataFrames with a ``label``
    column holding rows that bypass the model: they count as correct when
    their label is 0 / 1 respectively. They contribute to the accuracy only,
    not to the loss.

    When ``Config.use_wandb`` is set, the loss and the accuracy are logged in
    two separate ``wandb.log`` calls.
    """
    model.eval()
    batch_losses = []
    correctness = []
    for batch in tqdm.tqdm(dataloader):
        tokens = batch['tokens'].to(Config.device)
        targets = batch['labels'].to(Config.device)
        probabilities = model(tokens)

        batch_losses.append(criterion(probabilities, targets).item())

        targets_np = targets.cpu().numpy()
        predicted = probabilities.cpu().numpy()
        # Binarise in place: >= 0.5 becomes 1, everything below becomes 0.
        predicted[predicted >= .5] = 1
        predicted[predicted < .5] = 0
        correctness.append(predicted == targets_np)

        del batch
        del tokens
        del targets, targets_np
        del probabilities, predicted
        torch.cuda.empty_cache()

    def append_to_scores(df: pd.DataFrame, expected_label: int):
        # One single-element row per tweet so it stacks with the model's (batch, 1) blocks.
        hits = np.array([[flag] for flag in df['label'] == expected_label])
        if len(hits):
            correctness.append(hits)

    append_to_scores(should_be_neg, 0)
    append_to_scores(should_be_pos, 1)

    stacked = np.vstack(correctness)
    accuracy = stacked.mean()
    epoch_loss = np.array(batch_losses).mean()
    if Config.use_wandb:
        wandb.log({"valid loss": epoch_loss})
        wandb.log({"valid accuracy": accuracy})
    return epoch_loss, accuracy

In [48]:
@dataclass
class Config:
    """Central experiment configuration for the notebook.

    Most settings are read as class attributes (``Config.epochs``). Fields
    built with ``default_factory`` (``lora_target_modules``,
    ``head_model_str``) only exist on instances, so they are read through
    ``Config()``.

    Every setting is declared as an annotated dataclass field, so all of them
    appear in ``Config().__dict__`` (used for the wandb run config) and in the
    pickled configuration.
    """

    # Name
    comment: str = "cool name for charts"

    # Stats
    use_wandb: bool = False

    # Saving model and info
    records_filename: str = "./work_files/data_exploration/records.csv"
    model_filename_format: str = "./work_files/data_exploration/models/model{:02}.pkl"
    config_filename_format: str = "./work_files/data_exploration/configs/config{:02}.pkl"

    # Data
    no_of_samples: int = 200000
    validation_size: float = 0.05
    train_file_neg: str = 'data/train_neg.txt'
    train_file_pos: str = 'data/train_pos.txt'
    test_file: str = 'data/test_data.txt'

    # Pre-trained model
    model_name: str = 'albert-base-v2'

    # Hyperparameters
    epochs: int = 5
    batch_size: int = 20
    learning_rate: float = 1e-4
    # Was `weight_decay = float = 1e-4`: a chained assignment that rebound
    # `float` inside the class body and left weight_decay out of the fields.
    weight_decay: float = 1e-4
    scheduler_step_size: int = 5
    scheduler_gamma: float = 0.5

    # LoRA
    lora_r: int = 16
    lora_alpha: int = 32
    lora_target_modules: [str] = field(default_factory=lambda: ["query", "value"])
    lora_dropout: float = 0.5
    lora_bias: Literal["none", "all", "lora_only"] = "lora_only"

    # Head model
    # Evaluated at instantiation time; requires the notebook-level `head_model` to exist.
    head_model_str: str = field(default_factory=lambda: str(head_model))
    # Annotated so that it is a real dataclass field.
    last_layer_size: int = 64

    # Other stuff
    device: torch.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    random_seed: int = 42


# Pre-trained encoder selected by Config.model_name.
embedding_model = AutoModel.from_pretrained(Config.model_name, resume_download=None)
# Classification head: hidden_size -> last_layer_size -> 1. No activation on the
# last layer here; MyModel applies the sigmoid itself.
head_model = nn.Sequential(nn.Linear(embedding_model.config.hidden_size, Config.last_layer_size),
                           nn.ReLU(),
                           nn.Linear(Config.last_layer_size, 1))

# NOTE: `head_model` must already exist at this point, because instantiating
# Config() evaluates the head_model_str default factory.
lora_config = LoraConfig(
    r=Config.lora_r,
    lora_alpha=Config.lora_alpha,
    # Read from an instance: fields with a default_factory have no class-level attribute.
    target_modules=Config().lora_target_modules,
    lora_dropout=Config.lora_dropout,
    bias=Config.lora_bias,
)
lora_model = get_peft_model(embedding_model, lora_config)
my_model = MyModel(lora_model, head_model).to(Config.device)

# BCELoss expects probabilities, which matches the sigmoid output of MyModel.
criterion = nn.BCELoss()
optimizer = AdamW(my_model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)
# NOTE(review): with the defaults (epochs=5, scheduler_step_size=5) and one
# scheduler.step() per epoch, the first decay would land after the last
# training epoch - confirm this is intended.
scheduler = StepLR(optimizer, step_size=Config.scheduler_step_size, gamma=Config.scheduler_gamma)

In [35]:
# Start a Weights & Biases run only when tracking is enabled in Config.
if Config.use_wandb:
    wandb.init(
        project="CIL Project",
        name=Config().comment,
        # Instance __dict__: only attributes declared as dataclass fields show up here.
        config=Config().__dict__
    )
    wandb.watch(my_model, log_freq=100)

In [36]:
def load_train_data(input_parsing_funs: [Callable[[str], str]] = None) -> pd.DataFrame:
    """Load de-duplicated training tweets into a DataFrame.

    Reads up to ``Config.no_of_samples // 2`` lines from each of
    ``Config.train_file_neg`` (label 0) and ``Config.train_file_pos``
    (label 1). Every line is right-stripped and passed through each function
    in ``input_parsing_funs`` in order. A tweet whose processed text was
    already seen is skipped, so a text present in both files keeps the
    negative label, because the negative file is read first.

    Args:
        input_parsing_funs: text -> text transforms applied in order.
            ``None`` (the default) or an empty list applies no transform.

    Returns:
        DataFrame with columns ``tweet`` and ``label``.
    """
    # The default None used to be iterated directly and raised TypeError.
    parsing_funs = list(input_parsing_funs or ())
    tweets_set = set()
    tweets, labels = [], []
    count = Config.no_of_samples // 2

    def load_tweets(filename, label):
        with open(filename, 'r', encoding='utf-8') as f:
            for line in tqdm.tqdm(islice(f, count), total=count, desc='Loading Tweets'):
                line = line.rstrip()
                for fun in parsing_funs:
                    line = fun(line)
                if line not in tweets_set:
                    tweets_set.add(line)
                    tweets.append(line)
                    labels.append(label)

    load_tweets(Config.train_file_neg, 0)
    load_tweets(Config.train_file_pos, 1)

    return pd.DataFrame(data={'tweet': tweets, 'label': labels})


class InputParsing:
    """Token-level text normalisers for tweets.

    Each method splits its input on whitespace and joins the resulting tokens
    with single spaces, so runs of whitespace are collapsed as a side effect.
    """

    @staticmethod
    def remove_users(words):
        """Drop every token that is exactly ``<user>``."""
        kept = (token for token in words.split() if token != '<user>')
        return ' '.join(kept)

    @staticmethod
    def remove_hashtags(words):
        """Drop every token that starts with ``#``."""
        kept = []
        for token in words.split():
            if token.startswith('#'):
                continue
            kept.append(token)
        return ' '.join(kept)

    @staticmethod
    def unify_hashtags(words):
        """Replace every token that starts with ``#`` by ``<hashtag>``."""
        return ' '.join(
            '<hashtag>' if token.startswith('#') else token
            for token in words.split()
        )

    @staticmethod
    def unify_numbers(words):
        """Replace every token for which ``str.isnumeric`` holds by ``<number>``."""
        return ' '.join(
            '<number>' if token.isnumeric() else token
            for token in words.split()
        )

In [37]:
class TweetDataset(Dataset):
    """Map-style dataset that exposes the rows of a DataFrame by position."""

    def __init__(self, dataframe):
        # Kept as a public attribute; other code reads `dataset.dataframe`.
        self.dataframe = dataframe

    def __len__(self):
        """Number of rows in the wrapped frame."""
        row_count = len(self.dataframe)
        return row_count

    def __getitem__(self, index):
        """Return the row at positional ``index`` (``iloc`` semantics)."""
        row = self.dataframe.iloc[index]
        return row

In [38]:
def split_dataset(dataset):
    """Split ``dataset.dataframe`` into ``(train_frame, valid_frame)``.

    The validation part holds ``int(Config.validation_size * len(dataset))``
    rows. Rows are assigned by ``random_split`` using a generator seeded with
    ``Config.random_seed``.
    """
    total = len(dataset)
    n_valid = int(Config.validation_size * total)
    n_train = total - n_valid
    rng = torch.Generator().manual_seed(Config.random_seed)
    train_subset, valid_subset = random_split(dataset, [n_train, n_valid], generator=rng)
    frame = dataset.dataframe
    return frame.iloc[train_subset.indices], frame.iloc[valid_subset.indices]


def get_dataframes(input_parsing_funs: [Callable[[str], bool]] = None):
    """Load the training tweets and return ``(train_frame, valid_frame)``.

    ``input_parsing_funs`` is forwarded unchanged to ``load_train_data``.
    """
    all_tweets = load_train_data(input_parsing_funs)
    wrapped = TweetDataset(all_tweets)
    return split_dataset(wrapped)

In [39]:
def split_by_filter(df: pd.DataFrame,
                    filter_fun: Callable[[str], bool]) -> (pd.DataFrame, pd.DataFrame):
    """Partition ``df`` by applying ``filter_fun`` to the ``tweet`` column.

    Returns ``(matching, non_matching)``: the rows whose result equals
    ``True`` and the rows whose result equals ``False``.
    """
    verdicts = df['tweet'].apply(filter_fun)
    matching = df[verdicts.eq(True)]
    non_matching = df[verdicts.eq(False)]
    return matching, non_matching


def split_into_neg_unknown_pos(df: pd.DataFrame,
                               neg_filter: Callable[[str], bool],
                               pos_filter: Callable[[str], bool]) -> (pd.DataFrame, pd.DataFrame, pd.DataFrame):
    """Split ``df`` into ``(neg, unknown, pos)`` using two tweet filters.

    The positive filter is applied first, so a tweet matching both filters
    ends up in ``pos``. ``unknown`` holds the rows matched by neither filter.
    """
    pos_rows, remaining = split_by_filter(df, pos_filter)
    neg_rows, unknown_rows = split_by_filter(remaining, neg_filter)
    return neg_rows, unknown_rows, pos_rows


class TweetFilters:
    """Predicates over a tweet's text, usable as neg/pos filters."""

    @staticmethod
    def unclosed_parenthesis(tweet):
        """True when the tweet has more ``(`` than ``)`` characters."""
        opened = tweet.count('(')
        closed = tweet.count(')')
        return opened > closed

    @staticmethod
    def has_word_frame(tweet):
        """Example filter: True when ``'frame'`` occurs anywhere in the tweet (substring match)."""
        return 'frame' in tweet

    @staticmethod
    def has_word_thanks(tweet):
        """Example filter: True when ``'thanks'`` occurs anywhere in the tweet (substring match)."""
        return 'thanks' in tweet

    @staticmethod
    def no_filter(tweet):
        """Filter that matches nothing: always False."""
        return False

In [40]:
# Text transforms applied to every tweet while loading, in list order.
input_parsing_funs = [
    InputParsing.unify_hashtags,
]
# Rule-based filters: tweets matching them are taken out of the model's data.
neg_filter = TweetFilters.has_word_frame
pos_filter = TweetFilters.has_word_thanks

train_df, valid_df = get_dataframes(input_parsing_funs)
# Training: rows matched by either filter are discarded; only the unmatched rows are kept.
_, train_df, _ = split_into_neg_unknown_pos(train_df, neg_filter, pos_filter)
# Validation: matched rows are kept aside (neg_valid / pos_valid) and handed to valid_step later.
neg_valid, valid_df, pos_valid = split_into_neg_unknown_pos(valid_df, neg_filter, pos_filter)

train_dataset = TweetDataset(train_df)
valid_dataset = TweetDataset(valid_df)

Loading Tweets: 100%|██████████| 100000/100000 [00:00<00:00, 154393.53it/s]
Loading Tweets: 100%|██████████| 100000/100000 [00:00<00:00, 175535.31it/s]


In [41]:
def gen_tokenize_fun():
    """Build a collate function around the tokenizer of ``Config.model_name``.

    The returned callable takes a list of rows (each indexable by ``"tweet"``
    and ``"label"``) and returns a dict with the tokenizer output under
    ``"tokens"`` and a one-column label tensor under ``"labels"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(Config.model_name, resume_download=None)

    def tokenize(data):
        texts = []
        label_rows = []
        for row in data:
            texts.append(row["tweet"])
            # One single-element row per sample, giving a (batch, 1) label tensor.
            label_rows.append([row["label"]])
        labels = Tensor(label_rows)
        encoded = tokenizer(texts, truncation=True, padding=True, return_tensors="pt")
        return {"tokens": encoded, "labels": labels}

    return tokenize


def make_dataloader(dataset, shuffle: bool):
    """Create a DataLoader over ``dataset`` that tokenizes each batch on the fly.

    Batch size comes from ``Config.batch_size``. A fresh collate function,
    and with it a fresh tokenizer, is built on every call.
    """
    loader_options = {
        "dataset": dataset,
        "collate_fn": gen_tokenize_fun(),
        "batch_size": Config.batch_size,
        "shuffle": shuffle,
        "pin_memory": True,
    }
    return DataLoader(**loader_options)


def get_dataloaders(train_dataset, valid_dataset):
    """Return ``(train_loader, valid_loader)``; only the training loader shuffles."""
    train_loader = make_dataloader(train_dataset, shuffle=True)
    valid_loader = make_dataloader(valid_dataset, shuffle=False)
    return train_loader, valid_loader

In [42]:
train_loader, valid_loader = get_dataloaders(train_dataset, valid_dataset)

# One training pass plus one validation pass per epoch.
for epoch in range(Config.epochs):
    print(f"Epoch {epoch + 1}/{Config.epochs}:")

    train_loss = train_step(my_model, train_loader)
    # The rule-filtered validation rows are scored inside valid_step without running the model on them.
    valid_loss, valid_accuracy = valid_step(my_model, valid_loader, neg_valid, pos_valid)

    print(f"  TRAIN loss     = {train_loss}")
    print(f"  VALID loss     = {valid_loss}")
    print(f"  VALID accuracy = {valid_accuracy}")
    print(f"--------------------------------------------------------")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

Epoch 1/5:


100%|██████████| 8310/8310 [13:35<00:00, 10.19it/s]
100%|██████████| 438/438 [00:18<00:00, 23.90it/s]


  TRAIN loss     = 0.4153375732455802
  VALID loss     = 0.38497350065539415
  VALID accuracy = 0.8289749531060355
--------------------------------------------------------
Epoch 2/5:


100%|██████████| 8310/8310 [13:38<00:00, 10.15it/s]
100%|██████████| 438/438 [00:18<00:00, 23.66it/s]


  TRAIN loss     = 0.3578870075238167
  VALID loss     = 0.34439962332617474
  VALID accuracy = 0.8445327154363897
--------------------------------------------------------
Epoch 3/5:


100%|██████████| 8310/8310 [14:15<00:00,  9.71it/s]
100%|██████████| 438/438 [00:20<00:00, 21.43it/s]


  TRAIN loss     = 0.33691375745789887
  VALID loss     = 0.3334837592032538
  VALID accuracy = 0.8508220236124904
--------------------------------------------------------
Epoch 4/5:


100%|██████████| 8310/8310 [14:44<00:00,  9.40it/s]
100%|██████████| 438/438 [00:19<00:00, 22.12it/s]


  TRAIN loss     = 0.3243768064652611
  VALID loss     = 0.3314586176396641
  VALID accuracy = 0.8534701533708485
--------------------------------------------------------
Epoch 5/5:


100%|██████████| 8310/8310 [14:37<00:00,  9.47it/s]
100%|██████████| 438/438 [00:19<00:00, 22.34it/s]


  TRAIN loss     = 0.3156106497070312
  VALID loss     = 0.32689556781389667
  VALID accuracy = 0.854683879510096
--------------------------------------------------------


In [43]:
def save_model(model, accuracy: float, description: str = ''):
    """Pickle ``model`` and a default ``Config()``, then add a row to the records CSV.

    The next free index is the current number of rows in
    ``Config.records_filename``; it is substituted into
    ``Config.model_filename_format`` and ``Config.config_filename_format``.
    The records file is rewritten only after both pickles were written.

    Args:
        model: object to pickle (the whole object, not just a state dict).
        accuracy: value stored in the ``Accuracy`` column.
        description: value stored in the ``Description`` column.
    """
    records = pd.read_csv(Config.records_filename, index_col='Index')
    index = len(records)
    model_filename = Config.model_filename_format.format(index)
    config_filename = Config.config_filename_format.format(index)

    records.loc[model_filename] = {'Accuracy': accuracy, 'Description': description}
    # Context managers guarantee the files are flushed and closed, even on error.
    with open(model_filename, 'wb') as model_file:
        pickle.dump(model, model_file)
    with open(config_filename, 'wb') as config_file:
        pickle.dump(Config(), config_file)
    records.to_csv(Config.records_filename)


def load_model(index):
    """Unpickle and return the model stored under ``index``.

    The path is ``Config.model_filename_format.format(index)``.
    """
    model_filename = Config.model_filename_format.format(index)
    # SECURITY: pickle.load can execute arbitrary code; only load files you created yourself.
    with open(model_filename, 'rb') as model_file:
        return pickle.load(model_file)


def load_config(index):
    """Unpickle and return the configuration stored under ``index``.

    The path is ``Config.config_filename_format.format(index)``.
    """
    config_filename = Config.config_filename_format.format(index)
    # SECURITY: pickle.load can execute arbitrary code; only load files you created yourself.
    with open(config_filename, 'rb') as config_file:
        return pickle.load(config_file)


def load_records():
    """Read the records CSV at ``Config.records_filename``, indexed by its ``Index`` column."""
    records_path = Config.records_filename
    records = pd.read_csv(records_path, index_col='Index')
    return records

In [51]:
# Persist the trained model; valid_accuracy holds the value from the last epoch of the training loop.
save_model(my_model, valid_accuracy, Config.comment)

In [52]:
# Show the table of saved models (the cell's last expression is displayed).
load_records()

Unnamed: 0_level_0,Accuracy,Description
Index,Unnamed: 1_level_1,Unnamed: 2_level_1
models/model00.sav,0.86605,base model
models/model01.pkl,0.841,small test model
models/model02.pkl,0.79887,small test - no duplicates
models/model03.pkl,0.854866,model 100k no hashtags
models/model04.pkl,0.854866,model 100k no hashtags
models/model05.pkl,0.850457,model 100k yes hashtags
models/model06.pkl,0.850457,model 100k yes hashtags
models/model07.pkl,0.849565,model 100k - unclosed parenthesis always neg
models/model08.pkl,0.859129,model 200k - unclosed parenthesis always neg
models/model09.pkl,0.859901,model 200k


In [53]:
# Close the wandb run, if tracking was enabled.
if Config.use_wandb:
    wandb.finish()