In [None]:
# Run-level switches.
#   ENABLE_WANDB: log metrics and upload/download checkpoints via Weights & Biases.
#   TRAIN_COLAB:  Colab layout -- the repository is cloned next to the notebook
#                 instead of the notebook running from inside the checkout.
ENABLE_WANDB = True
TRAIN_COLAB = True

if TRAIN_COLAB:
    # NOTE(review): on a re-run the target folder already exists, so git reports an
    # error here; the `!` shell escape does not stop the notebook.
    !git clone https://github.com/elenanespolo/Sentiment_Sarcasm_Analysis

In [None]:
# update Colab folder after a push in the repository

%cd Sentiment_Sarcasm_Analysis
!git pull
%cd ..

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
import os
import transformers
import tqdm

# Where the BESSTIE CSVs live and where the dataset wrapper is imported from:
# on Colab the repository is cloned next to the notebook, locally we run from
# inside the repository.
if not TRAIN_COLAB:
    from dataset.besstie import dataset_besstie
    root_folder = "dataset/besstie/"
else:
    from Sentiment_Sarcasm_Analysis.dataset.besstie import dataset_besstie
    root_folder = "Sentiment_Sarcasm_Analysis/dataset/besstie/"

# NOTE: select here the split to use
# splits = {'train': 'train.csv', 'validation': 'valid.csv'}
splits = {'train': 'train_SS.csv', 'validation': 'valid_SS.csv'}

if not os.path.exists(root_folder):
    os.makedirs(root_folder)

# Both CSVs are fetched from the Hugging Face Hub as soon as either is missing.
split_paths = {key: os.path.join(root_folder, name) for key, name in splits.items()}
if not all(os.path.exists(path) for path in (split_paths["train"], split_paths["validation"])):
    print("Downloading BESSTIE dataset...")
    # Login using e.g. `huggingface-cli login` to access this dataset
    for key in ("train", "validation"):
        df = pd.read_csv("hf://datasets/unswnlporg/BESSTIE/" + splits[key])
        df.to_csv(split_paths[key], index=False)


In [None]:
def set_seed(seed):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG, and torch (CPU and all
    CUDA devices), then disables cuDNN autotuning and requests deterministic
    cuDNN algorithms.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


In [None]:
#TODO: add ability to choose validation subcategory of the dataset
dataset_CFG = {
    'dataset_name': 'BESSTIE',
    'task': 'Sentiment-Sarcasm',
    'variety': 'en-UK',
    'source': 'Reddit',
    'classes': ['0', '1']
}
CFG = {
    'lr': 2e-5,
    'start_epoch': 20,
    'epochs': 30,
    'batch_size': 8,
    'max_length': 200,
    'min_length': 1,
    **dataset_CFG,
    'model_name': 'bert-base-uncased',
    'classification_head': 'cross_talk_conv', # 'linear' or 'conv' or 'lstm' or 'multi_task_conv' or 'cross_talk_conv'
    'seed': 0,
}

df_train = pd.read_csv(os.path.join(root_folder, splits['train']))
if dataset_CFG['task'] == 'Sentiment-Sarcasm':
    labels_count = pd.concat([df_train['sarcasm'].value_counts().sort_index(),df_train['sentiment'].value_counts().sort_index()], axis=1).set_axis(labels=['sarcasm', 'sentiment'], axis=1)
    !git clone https://github.com/WeiChengTseng/Pytorch-PCGrad
    !mv Pytorch-PCGrad pcgrad_repo
    from pcgrad_repo.pcgrad import PCGrad
else:
    labels_count = df_train["label"].value_counts().sort_index()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(labels_count)
print("Using device:", device)
set_seed(CFG['seed'])

   sarcasm  sentiment
0     2315       2361
1      808        762
Using device: cuda


# Model

In [None]:
class MultiKernelConvHead(torch.nn.Module):
    """Parallel 1-D convolutions of different widths, each ReLU'd and mean-pooled
    over the sequence, concatenated and fed to a linear classifier.

    Args:
        input_size: input channels, i.e. ``H`` in an input of shape ``(B, H, L)``.
        hidden_size: output channels of *each* convolution.
        num_labels: number of output logits.
        kernel_sizes: one convolution per entry, zero-padded with ``k // 2``.
        dropout: dropout probability applied to the concatenated pooled features.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_labels: int,
        kernel_sizes=(2, 3, 5),
        dropout=0.1
    ):
        super().__init__()

        branches = []
        for width in kernel_sizes:
            branches.append(
                torch.nn.Conv1d(input_size, hidden_size, kernel_size=width, padding=width // 2)
            )
        self.convs = torch.nn.ModuleList(branches)

        self.activation = torch.nn.ReLU()
        self.pool = torch.nn.AdaptiveAvgPool1d(1)
        self.dropout = torch.nn.Dropout(dropout)

        pooled_width = hidden_size * len(kernel_sizes)
        self.classifier = torch.nn.Linear(pooled_width, num_labels)

    def forward(self, x):
        """Map ``x`` of shape ``(B, H, L)`` to logits of shape ``(B, num_labels)``."""
        # No attention mask reaches this head, so every position of x (and the
        # conv's own zero padding) contributes to the average.
        pooled = [self.pool(self.activation(conv(x))).squeeze(-1) for conv in self.convs]
        features = self.dropout(torch.cat(pooled, dim=1))
        return self.classifier(features)

class ConvClassificationHead(torch.nn.Module):
    """Conv1d (kernel 3, same-length padding) + ReLU + global average pooling,
    optionally followed by a linear layer that produces logits.

    Input is ``(B, input_size, L)``. With ``linear=True`` the output is
    ``(B, num_labels)``; with ``linear=False`` it is the pooled feature vector
    ``(B, hidden_size)`` and ``num_labels`` is unused.
    """

    def __init__(self, input_size: int, hidden_size: int, num_labels: int, linear=True):
        super().__init__()

        layers = [
            torch.nn.Conv1d(
                in_channels=input_size,
                out_channels=hidden_size,
                kernel_size=3,
                padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool1d(1),  # (B, hidden_size, 1)
            torch.nn.Flatten(),             # (B, hidden_size)
        ]
        if linear:
            layers.append(torch.nn.Linear(hidden_size, num_labels))
        # Attribute name and layer positions define the state_dict keys
        # (conv.0.*, conv.4.*) that saved checkpoints rely on.
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class MultiTaskConvHead(torch.nn.Module):
    """Two independent ConvClassificationHead branches, one per task
    (sentiment, sarcasm), applied to the same input features."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_sentiment_labels: int,
        num_sarcasm_labels: int
    ):
        super().__init__()

        self.sentiment_head = ConvClassificationHead(
            input_size=input_size,
            hidden_size=hidden_size,
            num_labels=num_sentiment_labels
        )

        self.sarcasm_head = ConvClassificationHead(
            input_size=input_size,
            hidden_size=hidden_size,
            num_labels=num_sarcasm_labels
        )

    def forward(self, sequence_output):
        """
        sequence_output: channel-first features for the Conv1d branches,
        shape: (batch, input_size, seq_len) -- e.g. BERT's last_hidden_state
        after `.transpose(1, 2)`. Passing (batch, seq_len, hidden_size)
        directly would make Conv1d treat seq_len as the channel axis.

        Returns a dict with "sentiment" logits (batch, num_sentiment_labels)
        and "sarcasm" logits (batch, num_sarcasm_labels).
        """

        sentiment_logits = self.sentiment_head(sequence_output)
        sarcasm_logits = self.sarcasm_head(sequence_output)

        return {
            "sentiment": sentiment_logits,
            "sarcasm": sarcasm_logits
        }

class CrossTalkHead(torch.nn.Module):
    """Shared conv encoder followed by two task branches (sentiment, sarcasm)
    that exchange information once ("cross-talk") before their output layers.

    Expects channel-first input ``(B, input_size, L)`` because the shared
    encoder is a Conv1d. Returns ``{"sentiment": (B, num_sentiment_labels),
    "sarcasm": (B, num_sarcasm_labels)}``.

    Note: there is no non-linearity between the embed, fuse and output layers,
    so everything after the shared encoder is an affine map of its output.
    Checkpoints trained before the symmetric fusion fix load fine (same
    parameter names) but will produce different outputs.
    """

    def __init__(
        self,
        input_size,
        conv_hidden_size,
        num_sentiment_labels,
        num_sarcasm_labels
    ):
        super().__init__()

        # shared Conv1d + ReLU + average-pool encoder, no classifier layer
        self.encoder = ConvClassificationHead(
            input_size=input_size,
            hidden_size=conv_hidden_size,
            num_labels = 0,
            linear = False
        )

        # task-specific embeddings
        self.sentiment_embed = torch.nn.Linear(
            conv_hidden_size, conv_hidden_size
        )
        self.sarcasm_embed = torch.nn.Linear(
            conv_hidden_size, conv_hidden_size
        )

        # cross-talk layers: [own embedding, other task's embedding] -> features
        self.sentiment_fuse = torch.nn.Linear(
            2 * conv_hidden_size, conv_hidden_size
        )
        self.sarcasm_fuse = torch.nn.Linear(
            2 * conv_hidden_size, conv_hidden_size
        )

        self.sentiment_out = torch.nn.Linear(
            conv_hidden_size, num_sentiment_labels
        )
        self.sarcasm_out = torch.nn.Linear(
            conv_hidden_size, num_sarcasm_labels
        )

    def forward(self, sequence_output):
        shared = self.encoder(sequence_output)

        # task-specific embeddings of the shared features
        sent_embed = self.sentiment_embed(shared)
        sarc_embed = self.sarcasm_embed(shared)

        # cross-talk: both fusions read the *pre-fusion* embeddings. Reusing one
        # variable for both stages made the sarcasm fusion consume the
        # already-fused sentiment features, so the exchange was one-sided.
        sent_feat = self.sentiment_fuse(
            torch.cat([sent_embed, sarc_embed], dim=-1)
        )
        sarc_feat = self.sarcasm_fuse(
            torch.cat([sarc_embed, sent_embed], dim=-1)
        )

        return {
            "sentiment": self.sentiment_out(sent_feat),
            "sarcasm": self.sarcasm_out(sarc_feat)
        }

In [None]:
def get_tokenizer_and_model(model_name: str):
    """Load the pretrained tokenizer and the bare encoder (no task head) for
    ``model_name`` via the ``transformers`` Auto classes.

    Returns:
        ``(tokenizer, model)`` -- the tokenizer is loaded first.
    """
    loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    encoder = transformers.AutoModel.from_pretrained(model_name)
    return (loaded_tokenizer, encoder)

def get_classification_head(method: str, input_size:int, hidden_size: int, num_labels: int):
    """Build the classification head named ``method``.

    Args:
        method: one of ``"linear"``, ``"conv"``, ``"lstm"``, ``"multi_conv"``,
            ``"multi_task_conv"``, ``"cross_talk_conv"``.
        input_size: feature size of the encoder output.
        hidden_size: hidden size of the head (LSTM hidden size per direction,
            conv output channels per kernel).
        num_labels: number of classes (per task for the multi-task heads).

    Returns:
        The head module. ``"lstm"`` returns a bare bidirectional ``torch.nn.LSTM``;
        the caller must add its own output projection.

    Raises:
        ValueError: if ``method`` is unknown.
    """
    if method == "linear":
        return torch.nn.Linear(input_size, num_labels)
    elif method == "conv":
        return ConvClassificationHead(input_size, hidden_size, num_labels)
    elif method == "lstm":
        return torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
    elif method == "multi_conv":
        # MultiKernelConvHead takes no `num_channels` argument (passing it raised
        # TypeError); its per-kernel channel count is `hidden_size`.
        return MultiKernelConvHead(
            input_size=input_size,
            hidden_size=hidden_size,
            num_labels=num_labels,
            kernel_sizes=(2, 3, 5),
            dropout=0.1
        )
    elif method == 'multi_task_conv':
        return MultiTaskConvHead(input_size, hidden_size, num_labels, num_labels)
    elif method == 'cross_talk_conv':
        return CrossTalkHead(input_size, hidden_size, num_labels, num_labels)
    else:
        raise ValueError(f"Unknown classification head method: {method}")


class MyClassifier(torch.nn.Module):
    """Pretrained transformer encoder followed by a configurable classification head.

    Args:
        base_model_name: model id passed to ``get_tokenizer_and_model``.
        classification_head_name: head type understood by ``get_classification_head``.
        num_labels: number of classes (per task for the multi-task heads).

    ``forward`` returns a logits tensor for the single-task heads, and a dict
    with ``"sentiment"`` and ``"sarcasm"`` logits for ``"multi_task_conv"`` and
    ``"cross_talk_conv"``.
    """

    # Heads built on Conv1d layers, which need channel-first input
    # (batch, hidden_size, seq_len).
    _CHANNEL_FIRST_HEADS = ("conv", "multi_conv", "multi_task_conv", "cross_talk_conv")

    def __init__(self, base_model_name, classification_head_name, num_labels):
        super().__init__()

        self.tokenizer, self.base_model = get_tokenizer_and_model(base_model_name)
        self.hidden_size = self.base_model.config.hidden_size
        self.dropout = torch.nn.Dropout(self.base_model.config.hidden_dropout_prob)

        self.classification_head_name = classification_head_name

        self.classification_head = get_classification_head(
            classification_head_name, self.hidden_size, self.hidden_size, num_labels
        )

        if classification_head_name == "lstm":
            # the bidirectional LSTM concatenates both directions -> 2 * hidden_size
            self.output_layer = torch.nn.Linear(self.hidden_size*2, num_labels)

    # Quoted annotation: evaluated lazily, so defining the class does not
    # require resolving `transformers` at class-creation time.
    def get_tokenizer(self) -> "transformers.PreTrainedTokenizer":
        return self.tokenizer

    def forward(self, inputs):
        """Encode ``inputs`` (keyword arguments for the base model, e.g.
        ``input_ids`` and ``attention_mask``) and apply the classification head.

        Raises:
            ValueError: if ``classification_head_name`` is not a known head.
        """
        outputs = self.base_model(**inputs)
        sequence = self.dropout(outputs.last_hidden_state)  # (batch, seq_len, hidden_size)
        name = self.classification_head_name

        if name == "linear":
            # classify from the first token ([CLS] for BERT-style tokenizers)
            return self.classification_head(sequence[:, 0, :])

        if name == "lstm":
            lstm_out, _ = self.classification_head(sequence)
            # NOTE(review): at position 0 the forward direction has only seen the
            # first token; only the backward half summarizes the full sequence.
            return self.output_layer(lstm_out[:, 0, :])

        if name in self._CHANNEL_FIRST_HEADS:
            # (batch, seq_len, hidden) -> (batch, hidden, seq_len) for Conv1d.
            # "multi_conv" and "multi_task_conv" used to fall through to None here.
            return self.classification_head(sequence.transpose(1, 2))

        raise ValueError(f"Unknown classification head method: {name}")


# Train

In [None]:
def train_SS(model, train_loader, optimizer, criterion, device):
    """Train the joint sarcasm/sentiment model for one epoch with a PCGrad-style optimizer.

    Args:
        model: called as ``model(inputs)``; returns a dict with ``"sarcasm"``
            and ``"sentiment"`` logits of shape ``(B, C)``.
        train_loader: yields dict batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"``; column 0 of ``"label"`` is the
            sarcasm target and column 1 the sentiment target.
        optimizer: exposes ``pc_backward(list_of_losses)``, ``step()`` and
            ``zero_grad()``.
        criterion: ``(sarcasm_loss_fn, sentiment_loss_fn)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(sarc_loss, sent_loss, sarc_acc, sent_acc)``: losses averaged over
        batches, accuracies as fractions of ``len(train_loader.dataset)``.
    """
    model.train()
    sarcasm_loss_fn, sentiment_loss_fn = criterion

    sarc_loss_total = sent_loss_total = 0.0
    sarc_correct = sent_correct = 0.0

    for batch in tqdm.tqdm(train_loader):
        inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask')}
        targets = batch['label'].to(device)
        logits = model(inputs)

        sarc_targets = targets[:,0]
        sent_targets = targets[:,1]
        task_losses = [
            sarcasm_loss_fn(logits['sarcasm'], sarc_targets),
            sentiment_loss_fn(logits['sentiment'], sent_targets),
        ]

        # pc_backward receives the per-task losses as a list and runs the
        # backward pass itself, so there is no manual .backward() here
        optimizer.pc_backward(task_losses)
        optimizer.step()
        optimizer.zero_grad()

        sarc_loss_total += task_losses[0].item()
        sent_loss_total += task_losses[1].item()
        sarc_correct += (logits['sarcasm'].max(dim=1).indices == sarc_targets).sum().item()
        sent_correct += (logits['sentiment'].max(dim=1).indices == sent_targets).sum().item()

    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    return (
        sarc_loss_total / n_batches,
        sent_loss_total / n_batches,
        sarc_correct / n_samples,
        sent_correct / n_samples,
    )


In [None]:
def train(model, train_loader, optimizer, criterion, device):
    """Train a single-task classifier for one epoch.

    Args:
        model: called as ``model(inputs)``; returns logits of shape ``(B, C)``.
        train_loader: yields dict batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"``.
        optimizer: a torch optimizer.
        criterion: loss function ``criterion(logits, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)``; accuracy is over ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss, n_correct = 0.0, 0.0

    for batch in tqdm.tqdm(train_loader):
        inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask')}
        targets = batch['label'].to(device)

        logits = model(inputs)
        loss = criterion(logits, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        n_correct += (logits.max(dim=1).indices == targets).sum().item()

    return running_loss / len(train_loader), n_correct / len(train_loader.dataset)

# Validation

In [None]:
def validate_SS(model, val_loader, criterion, device):
    """Evaluate the joint sarcasm/sentiment model on ``val_loader`` without gradients.

    ``criterion`` is ``(sarcasm_loss_fn, sentiment_loss_fn)``; column 0 of the
    batch ``"label"`` is the sarcasm target and column 1 the sentiment target.

    Returns:
        ``(sarc_loss, sarc_acc, sent_loss, sent_acc)`` -- note the order differs
        from ``train_SS``. Losses are averaged over batches, accuracies over
        ``len(val_loader.dataset)`` samples.
    """
    model.eval()
    sarc_loss_sum, sent_loss_sum = 0.0, 0.0
    sarc_hits, sent_hits = 0.0, 0.0

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask')}
            targets = batch['label'].to(device)
            logits = model(inputs)

            sarcasm_criterion, sentiment_criterion = criterion
            sarc_targets, sent_targets = targets[:,0], targets[:,1]

            sarc_loss_sum += sarcasm_criterion(logits['sarcasm'], sarc_targets).item()
            sent_loss_sum += sentiment_criterion(logits['sentiment'], sent_targets).item()

            sarc_hits += (logits['sarcasm'].max(dim=1).indices == sarc_targets).sum().item()
            sent_hits += (logits['sentiment'].max(dim=1).indices == sent_targets).sum().item()

    n_batches = len(val_loader)
    n_samples = len(val_loader.dataset)
    return sarc_loss_sum / n_batches, sarc_hits / n_samples, sent_loss_sum / n_batches, sent_hits / n_samples

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate a single-task classifier on ``val_loader`` without gradients.

    Returns:
        ``(mean_batch_loss, accuracy)``; accuracy is over ``len(val_loader.dataset)``.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0.0

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask')}
            targets = batch['label'].to(device)
            logits = model(inputs)

            loss_sum += criterion(logits, targets).item()
            hits += (logits.max(dim=1).indices == targets).sum().item()

    return loss_sum / len(val_loader), hits / len(val_loader.dataset)

# Wandb

In [None]:
if ENABLE_WANDB:
    import wandb

    # NOTE: set run_name
    run_name = "CrossTalk"
    # Deterministic id built from the configuration: with resume="allow" a
    # re-run with the same settings targets the same wandb run, and the id is
    # also used as the checkpoint artifact name.
    run_id = f"{run_name}_{CFG['model_name']}_{CFG['classification_head']}_{CFG['dataset_name']}_{dataset_CFG['variety']}"

    run = wandb.init(
        entity="elena-nespolo02-politecnico-di-torino",
        project="Figurative Analysis",
        name=run_name,
        id=run_id,
        resume="allow",
        config=CFG,
        tags=[CFG['dataset_name'], CFG['task'], CFG['model_name']]
    )

    # Each metric family is plotted against its own step counter.
    for step_metric, metric_family in (
        ("epoch/step", "epoch/*"),
        ("train/step", "train/*"),
        ("validate/step", "validate/*"),
    ):
        wandb.define_metric(step_metric)
        wandb.define_metric(metric_family, step_metric=step_metric)



# ML loop

In [None]:
models_root_dir = "./models"
!rm -rf {models_root_dir}
!mkdir {models_root_dir}

model_name = CFG['model_name']

tokenizer, model = get_tokenizer_and_model(model_name)

tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

# Encoder + configured head, with one logit per class listed in
# dataset_CFG['classes'] (per task for the multi-task heads).
model = MyClassifier(
    base_model_name=model_name,
    classification_head_name=CFG['classification_head'],
    num_labels=len(dataset_CFG['classes'])
).to(device)

# Arguments shared by the train and validation datasets; only the CSV differs.
besstie_kwargs = dict(
    root_folder=root_folder,
    classes=dataset_CFG['classes'],
    tokenizer=tokenizer,
    min_length=CFG['min_length'],
    max_length=CFG['max_length'],
    variety=CFG['variety'],
    source=CFG['source'],
    task=CFG['task'],
)

train_ds = dataset_besstie.BesstieDataSet(file_name=splits['train'], **besstie_kwargs)
val_ds = dataset_besstie.BesstieDataSet(file_name=splits['validation'], **besstie_kwargs)

def _balanced_class_weights(counts):
    """Inverse-frequency class weights ``n_samples / (n_classes * count_c)`` as a
    float tensor on ``device``: rarer classes get larger weights.

    The previous ``count_c / n_samples`` weighting gave the *majority* class the
    larger weight, amplifying the class imbalance instead of correcting it.
    """
    counts = np.asarray(counts, dtype=np.float64)
    return torch.tensor(counts.sum() / (len(counts) * counts), dtype=torch.float).to(device)

if dataset_CFG['task'] == 'Sentiment-Sarcasm':
    # Pass the configured learning rate: without it Adam silently used its 1e-3
    # default instead of CFG['lr'].
    optimizer = PCGrad(torch.optim.Adam(model.parameters(), lr=CFG['lr']))
    sarcasm_criterion = torch.nn.CrossEntropyLoss(
        weight=_balanced_class_weights(labels_count['sarcasm'].values)
    )
    sentiment_criterion = torch.nn.CrossEntropyLoss(
        weight=_balanced_class_weights(labels_count['sentiment'].values)
    )
    criterion = [sarcasm_criterion, sentiment_criterion] # a list of per-task losses
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'])
    criterion = torch.nn.CrossEntropyLoss(
        weight=_balanced_class_weights(labels_count.values)
    )

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader, val_loader = (
    torch.utils.data.DataLoader(ds, batch_size=CFG['batch_size'], shuffle=reshuffle)
    for ds, reshuffle in ((train_ds, True), (val_ds, False))
)

#TODO: gradient accumulation to reduce memory usage?
# accumulation_steps = 4  # Effective batch size = batch_size * accumulation_steps
# for i, batch in enumerate(train_dataloader):
#     outputs = model(**batch)
#     loss = outputs.loss / accumulation_steps
#     loss.backward()
#     if (i + 1) % accumulation_steps == 0:
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()
#         scheduler.step()
#         model.zero_grad()

# Loading from a starting point: restore model and optimizer state from the
# wandb artifact saved at epoch CFG['start_epoch'].
if CFG['start_epoch'] > 0 and ENABLE_WANDB:
    # The subscript inside the f-string uses quotes different from the outer
    # ones; reusing the same quote there is a SyntaxError before Python 3.12.
    artifact = run.use_artifact(f"elena-nespolo02-politecnico-di-torino/Figurative Analysis/{run_id}:epoch_{CFG['start_epoch']}", type='model')
    artifact_dir = artifact.download()

    artifact_path = os.path.join(artifact_dir, run_id+f"_epoch_{CFG['start_epoch']}.pth")

    checkpoint = torch.load(artifact_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    # PCGrad exposes the wrapped torch optimizer as `.optimizer`; a plain torch
    # optimizer (single-task run) has no such attribute and is used directly.
    base_optimizer = getattr(optimizer, "optimizer", optimizer)
    base_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # load_state_dict also restores the checkpoint's learning rate; keep CFG['lr']
    # (the value logged to wandb) authoritative.
    for param_group in base_optimizer.param_groups:
        param_group["lr"] = CFG['lr']
elif CFG['start_epoch'] > 0:
    print(f"WARNING: start_epoch={CFG['start_epoch']} but ENABLE_WANDB is False: "
          "no checkpoint is loaded, training continues from the freshly initialised model.")



In [None]:
for epoch in range(CFG['start_epoch']+1,CFG['epochs']+1):

    print(f"Epoch {epoch}/{CFG['epochs']}")

    if dataset_CFG['task'] == 'Sentiment-Sarcasm':
        epoch_sarc_loss, epoch_sent_loss, epoch_sarc_acc, epoch_sent_acc = train_SS(model, train_loader, optimizer, criterion, device)

        # validate_SS returns (loss, acc) per task -- a different order from train_SS
        val_sarc_loss, val_sarc_acc, val_sent_loss, val_sent_acc = validate_SS(model, val_loader, criterion, device)

        if ENABLE_WANDB:
            run.log({
                    "epoch/step": epoch,
                    "epoch/train_sarc_loss": epoch_sarc_loss,
                    "epoch/train_sent_loss": epoch_sent_loss,
                    "epoch/train_sarc_acc": epoch_sarc_acc,
                    "epoch/train_sent_acc": epoch_sent_acc,
                    "epoch/val_sarc_loss": val_sarc_loss,
                    "epoch/val_sent_loss": val_sent_loss,
                    "epoch/val_sarc_acc": val_sarc_acc,
                    "epoch/val_sent_acc": val_sent_acc
                },
                commit=True,
            )
        print(f"Training Sarcasm Loss: {epoch_sarc_loss:.4f}")
        print(f"Training Sentiment Loss: {epoch_sent_loss:.4f}")
        print(f"Training Sarcasm Acc: {epoch_sarc_acc:.4f}")
        print(f"Training Sentiment Acc: {epoch_sent_acc:.4f}")
        print(f"Validation Sarcasm Loss: {val_sarc_loss:.4f}")
        print(f"Validation Sentiment Loss: {val_sent_loss:.4f}")
        print(f"Validation Sarcasm Acc: {val_sarc_acc:.4f}")
        print(f"Validation Sentiment Acc: {val_sent_acc:.4f}")
    else:
        epoch_loss, epoch_acc = train(model, train_loader, optimizer, criterion, device)

        val_loss, val_acc = validate(model, val_loader, criterion, device)

        if ENABLE_WANDB:
            run.log({
                    "epoch/step": epoch,
                    "epoch/train_loss": epoch_loss,
                    "epoch/train_acc": epoch_acc,
                    "epoch/val_loss": val_loss,
                    "epoch/val_acc": val_acc
                },
                commit=True,
            )

        print(f"Training Loss: {epoch_loss:.4f}")
        print(f"Training Acc: {epoch_acc:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Acc: {val_acc:.4f}")

    # Every 5 epochs: save a checkpoint locally and upload it as a wandb artifact
    # with aliases "latest" and "epoch_<n>" (the latter is what resuming loads).
    if (epoch % 5) == 0 and ENABLE_WANDB:
        checkpoint = {
            "model_state_dict": model.state_dict(),
            # PCGrad exposes the wrapped optimizer as `.optimizer`; a plain torch
            # optimizer (single-task run) is saved directly.
            "optimizer_state_dict": getattr(optimizer, "optimizer", optimizer).state_dict(),
            "epoch/step": epoch
        }

        file_name = f"{run_id}_epoch_{epoch}.pth"

        # Saving the progress
        file_path = os.path.join(models_root_dir, file_name)
        torch.save(checkpoint, file_path)

        print(f"Model saved to {file_path}")

        artifact = wandb.Artifact(name=run_id, type="model")
        artifact.add_file(file_path)

        run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

if ENABLE_WANDB:
    run.finish()

In [None]:
# `run` only exists when wandb logging is enabled; without this guard the cell
# raises NameError for ENABLE_WANDB = False.
if ENABLE_WANDB:
    run.finish()