In [None]:
import numpy as np
import pandas as pd
import os

from utils.utils import plot, set_all_seeds
from utils.common import DataModule, Trainer

In [None]:
# Checkpoint identifier handed to both DataModule and Trainer further down.
checkpoint = 'roberta-base'

# Task names passed to DataModule and Trainer:
#   empathy       -> LLM annotation
#   wrong_empathy -> crowdsource annotation
task = ['empathy', 'wrong_empathy']

# Feature(s) to tokenise. Other options kept for reference:
#   ['demographic_essay', 'article']
#   ['demographic', 'essay']
feature_to_tokenise = ['demographic_essay']

# ---------------- Combined train file ----------------
# Variant without augmentation: './data/WS22-WS23-sep-from-aug-train-gpt.tsv'
train_file = './data/WS22-WS23-augmented-train-gpt.tsv'

# ---------------- WASSA 2022 (currently disabled) ----------------
# dev_file = './data/WS22-dev-gpt.tsv'
# dev_label_crowd = './data/WASSA22/goldstandard_dev_2022.tsv'
# dev_label_gpt = './data/WS22-dev-gpt.tsv'

# ---------------- WASSA 2023 (currently active) ----------------
dev_label_crowd = './data/WASSA23/goldstandard_dev.tsv'
dev_label_gpt = './data/WS23-dev-gpt.tsv'
dev_file = './data/WS23-dev-gpt.tsv'

In [None]:
# Seeds to sweep; the training loop runs its full anno_diff sweep once per seed.
seed_range = [0, 42, 100, 999, 1234]

# anno_diff values tried for every seed: 0.0, 0.5, ..., 6.0 (the stop value 6.5 is excluded).
anno_diff_range = np.arange(start=0, stop=6.5, step=0.5)

# Mode forwarded to Trainer: -1 = crowd, 1 = gpt, 0 = crowd-gpt.
mode = 0

In [None]:
# Grid of trainer.fit() results: one row per seed, one column per anno_diff value.
# Pre-allocated with a float dtype so the layout is fixed up front (instead of the
# frame being enlarged one cell at a time) and unfinished runs simply show as NaN.
val_results = pd.DataFrame(index=seed_range, columns=anno_diff_range, dtype=float)

for seed in seed_range:

    # NOTE(review): seeding happens once per seed, not once per anno_diff, so every
    # Trainer after the first one presumably starts from the RNG state left behind
    # by the previous run — confirm this is intended.
    set_all_seeds(seed)

    data_module = DataModule(
        task=task,
        checkpoint=checkpoint,
        batch_size=16,
        feature_to_tokenise=feature_to_tokenise,
        seed=seed
    )

    # Loaders are built once per seed and shared by every anno_diff run below.
    train_loader = data_module.dataloader(file=train_file, send_label=True, shuffle=True)
    dev_loader = data_module.dataloader(file=dev_file, send_label=False, shuffle=False)

    for anno_diff in anno_diff_range:
        trainer = Trainer(
            task=task,
            checkpoint=checkpoint,
            lr=1e-5,
            n_epochs=10,
            train_loader=train_loader,
            dev_loader=dev_loader,
            dev_label_gpt=dev_label_gpt,
            dev_label_crowd=dev_label_crowd,
            device_id=0,
            anno_diff=anno_diff,
            mode=mode
        )

        ## If we want to save model to use while testing
        ## (the test section loads the './ws23ckp/pearson-...' file of its chosen seed / anno_diff)
        # save_as_loss = './ws23ckp/loss-llm-roberta-seed-' + str(seed) + '-anno_diff-' + str(anno_diff) + '.pth'
        # save_as_pearson = './ws23ckp/pearson-llm-roberta-seed-' + str(seed) + '-anno_diff-' + str(anno_diff) + '.pth'

        # Both save paths are None here; presumably that disables checkpoint writing — TODO confirm.
        val_pearson_r = trainer.fit(save_as_loss=None, save_as_pearson=None, dev_alpha=True)

        # save as seed in index and anno_diff in columns
        print(f'\n----Seed {seed}, anno_diff {anno_diff}: {val_pearson_r}----\n')
        val_results.loc[seed, anno_diff] = val_pearson_r

    # Saving in each seed to be cautious
    # val_results.to_csv('ws23-val_results_diff_seed_anno_diff.tsv', sep='\t')

# Test

## WASSA 2023 (WS23)

In [None]:
# Which (seed, anno_diff) checkpoint to evaluate on the WS23 test set.
seed = 0
anno_diff = 5.0

# Test inputs.
test_file = './data/PREPROCESSED-WS23-test.tsv'

# Checkpoint path; follows the 'pearson-llm-roberta-seed-<seed>-anno_diff-<anno_diff>.pth'
# naming shown in the training cell's save_as_pearson example.
load_model = f'./ws23ckp/pearson-llm-roberta-seed-{seed}-anno_diff-{anno_diff}.pth'

## WASSA 2022 (WS22)

In [None]:
# seed = 1234
# anno_diff = 6.0

# test_file = './data/PREPROCESSED-WS22-test.tsv'
# load_model = './ws22ckp/pearson-llm-roberta-seed-' + str(seed) + '-anno_diff-' + str(anno_diff) + '.pth'

## Let's test

In [None]:
# Seed with the value chosen in the test configuration cell.
set_all_seeds(seed)

# Same DataModule settings as in the training loop.
data_module = DataModule(
    task=task, checkpoint=checkpoint, batch_size=16,
    feature_to_tokenise=feature_to_tokenise, seed=seed,
)

print('Working with', test_file)
# No labels are requested and the order is kept (shuffle=False).
test_loader = data_module.dataloader(
    file=test_file,
    send_label=False,
    shuffle=False,
)

# Trainer without any train/dev loaders or dev label files: only evaluate() is called on it later.
# NOTE(review): lr and n_epochs are still passed; presumably unused for evaluation — confirm.
trainer = Trainer(
    task=task, checkpoint=checkpoint,
    lr=1e-5, n_epochs=10,
    train_loader=None, dev_loader=None,
    dev_label_gpt=None, dev_label_crowd=None,
    device_id=0,
    anno_diff=anno_diff,
    mode=0,  # -1: crowd, 1: gpt, 0: crowd-gpt
)

In [None]:
# Evaluate the chosen checkpoint on the test loader and write the submission file.
print('Working with', load_model)
pred = trainer.evaluate(dataloader=test_loader, load_model=load_model)

# we're not predicting distress, just aligning with submission system
pred_df = pd.DataFrame({'emp': pred, 'dis': pred})

# to_csv does not create missing parent directories, so make sure ./tmp exists first
# (it will not on a fresh checkout).
os.makedirs('./tmp', exist_ok=True)

# Tab-separated, with neither the index column nor a header row.
pred_df.to_csv('./tmp/predictions_EMP.tsv', sep='\t', index=False, header=False)

In [None]:
# Zip the prediction file written to ./tmp.
# NOTE: %cd changes the kernel's working directory to tmp; it is only restored by the
# `%cd ../` in the clean-up cell, so run that cell before re-running anything that
# uses relative paths.
%cd tmp
!zip predictions.zip predictions_EMP.tsv

In [None]:
# Clean-up: delete the prediction file and its zip, then step back out of tmp.
# NOTE(review): this removes predictions.zip as well — presumably meant to be run only
# after the archive has been submitted/copied elsewhere; confirm before running.
# Assumes the working directory is still tmp (i.e. the zip cell ran exactly once before this).
!rm predictions_EMP.tsv predictions.zip
%cd ../