In [17]:
import numpy as np
import pandas as pd
import os
import random

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset

from transformers import (
    DataCollatorWithPadding,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from datasets import Dataset

from sklearn.preprocessing import MinMaxScaler

os.chdir("/g/data/jr19/rh2942/text-empathy/")
from evaluation import pearsonr
from utils.utils import plot, get_device, set_all_seeds
from utils.common import EarlyStopper

In [18]:
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # due to huggingface warning
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

In [19]:
class Empathy(nn.Module):
    def __init__(self, checkpoint):
        super(Empathy, self).__init__()
        self.transformer = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=768)
        self.fc1 = nn.Sequential(
            nn.Linear(768, 512), nn.Tanh(), nn.Dropout(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512+5+1, 256), nn.Tanh(), nn.Dropout(0.2),
            nn.Linear(256, 1)
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        gender=None,
        education=None,
        race=None,
        age=None,
        income=None,
        distress=None
    ):

        output = self.transformer(
            input_ids= input_ids,
            attention_mask=attention_mask,
        )

        output = self.fc1(output.logits)
        output = torch.cat([output, gender, education, race, age, income, distress], 1)
        output = self.fc2(output)
        return output

In [20]:
class Distress(nn.Module):
    def __init__(self, checkpoint):
        super(Distress, self).__init__()
        self.transformer = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=768)
        self.fc1 = nn.Sequential(
            nn.Linear(768, 512), nn.Tanh(), nn.Dropout(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512+5, 256), nn.Tanh(), nn.Dropout(0.2),
            nn.Linear(256, 1)
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        gender=None,
        education=None,
        race=None,
        age=None,
        income=None
    ):

        output = self.transformer(
            input_ids= input_ids,
            attention_mask=attention_mask,
        )

        output = self.fc1(output.logits)
        output = torch.cat([output, gender, education, race, age, income], 1)
        output = self.fc2(output)
        return output

In [21]:
class DataModule:
    def __init__(self, task, checkpoint, batch_size, feature_to_tokenise, seed):

        self.task = task
        self.checkpoint = checkpoint
        self.batch_size = batch_size
        self.tokeniser = AutoTokenizer.from_pretrained(
            self.checkpoint,
            use_fast=True
        )
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
        self.feature_to_tokenise = feature_to_tokenise # to tokenise function
        self.seed = seed

        assert len(self.task) == 2, 'task must be a list with two elements'
    
    def _process_raw(self, path, send_label):
        data = pd.read_csv(path, sep='\t')
    
        if send_label:
            text = data[self.feature_to_tokenise + self.task]
        else:
            text = data[self.feature_to_tokenise]

        demog = ['gender', 'education', 'race', 'age', 'income']
        data_demog = data[demog]
        scaler = MinMaxScaler()
        data_demog = pd.DataFrame(
            scaler.fit_transform(data_demog),
            columns=demog
        )
        data = pd.concat([text, data_demog], axis=1) 
        return data

    def _tokeniser_fn(self, sentence):
        if len(self.feature_to_tokenise) == 1: # only one feature
            return self.tokeniser(sentence[self.feature_to_tokenise[0]], truncation=True)
        # otherwise tokenise a pair of sentence
        return self.tokeniser(sentence[self.feature_to_tokenise[0]], sentence[self.feature_to_tokenise[1]], truncation=True)

    def _process_input(self, file, send_label):
        data = self._process_raw(path=file, send_label=send_label)
        data = Dataset.from_pandas(data, preserve_index=False) # convert to huggingface dataset
        data = data.map(self._tokeniser_fn, batched=True, remove_columns=self.feature_to_tokenise) # tokenise
        data = data.with_format('torch')
        return data

    # taken from https://pytorch.org/docs/stable/notes/randomness.html
    def _seed_worker(self, worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)     

    def dataloader(self, file, send_label, shuffle):
        data = self._process_input(file=file, send_label=send_label)

        # making sure the shuffling is reproducible
        g = torch.Generator()
        g.manual_seed(self.seed)
        
        return DataLoader(
            data,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.data_collator,
            num_workers=24,
            worker_init_fn=self._seed_worker,
            generator=g
        )

In [39]:
class Trainer:
    def __init__(self, task, model_distress, model_empathy, lr, n_epochs_distress, n_epochs_empathy, train_loader,
                 dev_loader, dev_label_file, device_id=0):
        self.device = get_device(device_id)
        self.task = task
        self.model_distress = model_distress.to(self.device)
        self.model_empathy = model_empathy.to(self.device)
        self.lr = lr
        self.n_epochs_distress = n_epochs_distress
        self.n_epochs_empathy = n_epochs_empathy
        self.train_loader = train_loader
        self.dev_loader = dev_loader
        self.dev_label_file = dev_label_file
        
        self.loss_fn = nn.MSELoss()
        self.optimiser_distress = torch.optim.AdamW(
            params=self.model_distress.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )
        
        self.optimiser_empathy = torch.optim.AdamW(
            params=self.model_empathy.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )

        n_training_step_distress = self.n_epochs_distress*len(self.train_loader)
        self.lr_scheduler_distress = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_distress,
            num_warmup_steps=0.06*n_training_step_distress,
            num_training_steps=n_training_step_distress
        )
        
        n_training_step_empathy = self.n_epochs_empathy*len(self.train_loader)
        self.lr_scheduler_empathy = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_empathy,
            num_warmup_steps=0.06*n_training_step_empathy,
            num_training_steps=n_training_step_empathy
        )
        
        self.best_pearson_r = -1.0 # initiliasation
        self.early_stopper_distress = EarlyStopper(patience=3, min_delta=0.01)
        self.early_stopper_empathy = EarlyStopper(patience=3, min_delta=0.01)
        
        assert len(self.task) == 2, 'task must be a list with two elements'
        assert self.task[0] == 'distress', 'First item of task list should be the first guide - distress'
    
    def _freeze_unfreeze(self, model, freeze=False):
        for param in model.parameters():
            param.requires_grad = not freeze # if freeze is required (True): requires_grad is False
            
    def _training_step_distress(self):
        tr_loss_distress = 0.0
        idx = 0
        guide = torch.empty((len(self.train_loader.dataset), 1), device=self.device)
        
        self.model_distress.train()
    
        for data in self.train_loader:
            input_ids = data['input_ids'].to(self.device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
            
            distress = data[self.task[0]].to(self.device, dtype=torch.float).view(-1, 1)
            
            gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
            education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
            race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
            age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
            income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
    
            ### Training distress model
            self._freeze_unfreeze(model_distress, freeze=False)
            self._freeze_unfreeze(model_empathy, freeze=True)
            
            outputs_distress = self.model_distress(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income
            )
            loss = self.loss_fn(outputs_distress, distress)
            tr_loss_distress += loss.item()

            self.optimiser_distress.zero_grad()
            loss.backward()
            self.optimiser_distress.step()
            self.lr_scheduler_distress.step()

            batch_size = outputs_distress.shape[0]
            guide[idx:idx+batch_size, :] = outputs_distress
            idx += batch_size
            
        print(f'Train loss (distress): {tr_loss_distress / len(train_loader)}')
        return guide.detach()

    def _training_step_empathy(self, guide):
        tr_loss_empathy = 0.0
        idx = 0

        self.model_empathy.train()
    
        for data in self.train_loader:
            input_ids = data['input_ids'].to(self.device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
            
            empathy = data[self.task[1]].to(self.device, dtype=torch.float).view(-1, 1)
            
            gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
            education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
            race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
            age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
            income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)

            batch_size = empathy.shape[0]
            batched_guide = guide[idx:idx+batch_size, :]
            idx += batch_size
            
            ### Training empathy model
            self._freeze_unfreeze(model_distress, freeze=True)
            self._freeze_unfreeze(model_empathy, freeze=False)
            
            outputs_empathy = self.model_empathy(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income,
                distress=batched_guide
            )
            loss = self.loss_fn(outputs_empathy, empathy)
            tr_loss_empathy += loss.item()
    
            self.optimiser_empathy.zero_grad()
            loss.backward()
            self.optimiser_empathy.step()
            self.lr_scheduler_empathy.step()
            
        print(f'Train loss (empathy): {tr_loss_empathy / len(train_loader)}')

    def _fit_distress(self, save_model=False):
        print('--- Training distress model ---')
        dev_label = pd.read_csv(self.dev_label_file, sep='\t', header=None)
        true_distress = dev_label.iloc[:, 1].tolist()
        
        for epoch in range(self.n_epochs_distress):
            print(f'Epoch: {epoch+1}')
            guide = self._training_step_distress()

            (preds_distress, _) = self.evaluate(dataloader=self.dev_loader, load_model=False)
            
            pearson_r_distress = pearsonr(true_distress, preds_distress)
            print(f'Pearson r (distress): {pearson_r_distress}')
            
            val_loss_distress = self.loss_fn(torch.tensor(preds_distress), torch.tensor(true_distress))
            print('Validation loss (distress):', val_loss_distress.item())
            
            if self.early_stopper_distress.early_stop(val_loss_distress):
                break

            # if (pearson_r_empathy > self.best_pearson_r):
            #     self.best_pearson_r = pearson_r_empathy            
            #     if save_model:
            #         torch.save(self.model.state_dict(), 'EmpathGuRo.pth')
            #         print("Saved the model in epoch " + str(epoch+1))
            
            # print(f'Best dev set Pearson r (empathy): {self.best_pearson_r}\n')
            print()
        return guide

    def fit_empathy(self, save_model=False):
        dev_label = pd.read_csv(self.dev_label_file, sep='\t', header=None)
        true_empathy = dev_label.iloc[:, 0].tolist()
        
        guide = self._fit_distress()
        
        print('\n\n--- Training empathy model ---')
        
        for epoch in range(self.n_epochs_empathy):
            print(f'Epoch: {epoch+1}')
            self._training_step_empathy(guide)

            (preds_distress, preds_empathy) = self.evaluate(dataloader=self.dev_loader, load_model=False)
            
            pearson_r_empathy = pearsonr(true_empathy, preds_empathy)
            print(f'Pearson r (empathy): {pearson_r_empathy}')
            
            val_loss_empathy = self.loss_fn(torch.tensor(preds_empathy), torch.tensor(true_empathy))
            print('Validation loss (empathy):', val_loss_empathy.item())
            
            if self.early_stopper_empathy.early_stop(val_loss_empathy):
                break

            if (pearson_r_empathy > self.best_pearson_r):
                self.best_pearson_r = pearson_r_empathy            
                if save_model:
                    torch.save(self.model.state_dict(), 'EmpathGuRo.pth')
                    print("Saved the model in epoch " + str(epoch+1))
            
            print(f'Best dev set Pearson r (empathy): {self.best_pearson_r}\n')

    def evaluate(self, dataloader, load_model=False):
        if load_model:
            self.model.load_state_dict(torch.load('EmpathGuRo.pth'))
    
        pred_distress = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        pred_empathy = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        self.model_distress.eval()
        self.model_empathy.eval()
    
        with torch.no_grad():
            idx = 0
            for data in dataloader:
                input_ids = data['input_ids'].to(self.device, dtype=torch.long)
                attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
                gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
                education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
                race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
                age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
                income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
                # self.prev_empathy = self.prev_empathy.to(self.device, dtype=torch.float).view(-1, 1)
        
                outputs_distress = self.model_distress(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income
                )
                
                batch_size = outputs_distress.shape[0]
                pred_distress[idx:idx+batch_size, :] = outputs_distress
                
                outputs_empathy = self.model_empathy(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income,
                    distress=outputs_distress
                )
        
                pred_empathy[idx:idx+batch_size, :] = outputs_empathy
                
                idx += batch_size
            
        return ([float(k) for k in pred_distress], [float(k) for k in pred_empathy])

In [40]:
checkpoint = 'roberta-base'
task = ['distress', 'empathy'] #guide first
# feature_to_tokenise=['demographic_essay', 'article']
# feature_to_tokenise=['demographic', 'essay']
feature_to_tokenise=['demographic_essay']
seed = 0

# train_file = './data/essay-train-ws22-ws23.tsv'
train_file = './data/PREPROCESSED-WS22-WS23-train.tsv'
# train_file = './data/COMBINED-PREPROCESSED-PARAPHRASED-WS22-WS23-train.tsv'

# WASSA 2022
# dev_file = './data/PREPROCESSED-WS22-dev.tsv'
# dev_label_file = './data/WASSA22/goldstandard_dev_2022.tsv'
# test_file = './data/PREPROCESSED-WS22-test.tsv'

# WASSA 2023
dev_file = './data/PREPROCESSED-WS23-dev.tsv'
dev_label_file = './data/WASSA23/goldstandard_dev.tsv'
test_file = './data/PREPROCESSED-WS23-test.tsv'

In [41]:
set_all_seeds(seed)

data_module = DataModule(
    task=task,
    checkpoint=checkpoint,
    batch_size=16,
    feature_to_tokenise=feature_to_tokenise,
    seed=0
)

train_loader = data_module.dataloader(file=train_file, send_label=True, shuffle=True)
dev_loader = data_module.dataloader(file=dev_file, send_label=False, shuffle=False)
test_loader = data_module.dataloader(file=test_file, send_label=False, shuffle=False)

model_distress = Distress(checkpoint=checkpoint)
model_empathy = Empathy(checkpoint=checkpoint)

trainer = Trainer(
    task=task,
    model_distress=model_distress,
    model_empathy=model_empathy,
    lr=1e-5,
    n_epochs_distress=10,
    n_epochs_empathy=20,
    train_loader=train_loader,
    dev_loader=dev_loader,
    dev_label_file=dev_label_file,
    device_id=0
)

trainer.fit_empathy(save_model=False)

Map:   0%|          | 0/2636 [00:00<?, ? examples/s]

Map:   0%|          | 0/208 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


--- Training distress model ---
Epoch: 1
Train loss (distress): 9.668442368507385
Pearson r (distress): 0.068
Validation loss (distress): 3.095682382583618

Epoch: 2
Train loss (distress): 3.7979113867788605
Pearson r (distress): 0.289
Validation loss (distress): 3.021509885787964

Epoch: 3
Train loss (distress): 3.327391808683222
Pearson r (distress): 0.391
Validation loss (distress): 2.8054873943328857

Epoch: 4
Train loss (distress): 2.7452551621379273
Pearson r (distress): 0.459
Validation loss (distress): 2.6769156455993652

Epoch: 5
Train loss (distress): 2.219888780333779
Pearson r (distress): 0.499
Validation loss (distress): 2.472139835357666

Epoch: 6
Train loss (distress): 1.8720871527989706
Pearson r (distress): 0.514
Validation loss (distress): 2.7824440002441406

Epoch: 7
Train loss (distress): 1.5658171830755292
Pearson r (distress): 0.516
Validation loss (distress): 2.6947550773620605

Epoch: 8
Train loss (distress): 1.3129501481850943
Pearson r (distress): 0.494
Valida

# Test

In [92]:
pred = trainer.evaluate(dataloader=test_loader, load_model=True)
pred_df = pd.DataFrame({'emp': pred, 'dis': pred}) # we're not predicting distress, just aligning with submission system
pred_df.to_csv('./tmp/predictions_EMP.tsv', sep='\t', index=None, header=None)
pred_df

Unnamed: 0,emp,dis
0,5.240249,5.240249
1,4.962897,4.962897
2,5.524402,5.524402
3,5.303425,5.303425
4,5.142904,5.142904
...,...,...
95,4.535624,4.535624
96,4.815518,4.815518
97,3.978409,3.978409
98,5.163967,5.163967


# Extra

## Batch-level training

In [None]:
class Trainer:
    def __init__(self, task, model_distress, model_empathy, lr, n_epochs, train_loader,
                 dev_loader, dev_label_file, device_id=0):
        self.device = get_device(device_id)
        self.task = task
        self.model_distress = model_distress.to(self.device)
        self.model_empathy = model_empathy.to(self.device)
        self.lr = lr
        self.n_epochs = n_epochs
        self.train_loader = train_loader
        self.dev_loader = dev_loader
        self.dev_label_file = dev_label_file
        
        self.loss_fn = nn.MSELoss()
        self.optimiser_distress = torch.optim.AdamW(
            params=self.model_distress.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )
        
        self.optimiser_empathy = torch.optim.AdamW(
            params=self.model_empathy.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )

        n_training_step = self.n_epochs*len(self.train_loader)
        
        self.lr_scheduler_distress = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_distress,
            num_warmup_steps=0.06*n_training_step,
            num_training_steps=n_training_step
        )
        
        self.lr_scheduler_empathy = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_empathy,
            num_warmup_steps=0.06*n_training_step,
            num_training_steps=n_training_step
        )
        
        self.best_pearson_r = -1.0 # initiliasation
        self.early_stopper = EarlyStopper(patience=3, min_delta=0.01)
        # self.prev_empathy = torch.rand(dev_loader.batch_size, 1) * 6 + 1 # random initialisation between 1.0 to 7.0
        
        assert len(self.task) == 2, 'task must be a list with two elements'
        assert self.task[0] == 'distress', 'First item of task list should be the first guide - distress'
    
    def _freeze_unfreeze(self, model, freeze=False):
        for param in model.parameters():
            param.requires_grad = not freeze # if freeze is required (True): requires_grad is False
            
    def _training_step(self):
        tr_loss_distress = 0.0
        tr_loss_empathy = 0.0
        
        self.model_distress.train()
        self.model_empathy.train()
    
        for data in self.train_loader:
            input_ids = data['input_ids'].to(self.device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
            
            distress = data[self.task[0]].to(self.device, dtype=torch.float).view(-1, 1)
            empathy = data[self.task[1]].to(self.device, dtype=torch.float).view(-1, 1)
            
            gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
            education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
            race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
            age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
            income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
    
            ### Training distress model
            self._freeze_unfreeze(model_distress, freeze=False)
            self._freeze_unfreeze(model_empathy, freeze=True)
            
            outputs_distress = self.model_distress(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income,
                # empathy=empathy
            )
            loss = self.loss_fn(outputs_distress, distress)
            tr_loss_distress += loss.item()

            self.optimiser_distress.zero_grad()
            loss.backward()
            self.optimiser_distress.step()
            self.lr_scheduler_distress.step()
            
            ### Training empathy model
            self._freeze_unfreeze(model_distress, freeze=True)
            self._freeze_unfreeze(model_empathy, freeze=False)
            
            outputs_empathy = self.model_empathy(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income,
                distress=outputs_distress.detach()
            )
            loss = self.loss_fn(outputs_empathy, empathy)
            tr_loss_empathy += loss.item()
    
            self.optimiser_empathy.zero_grad()
            loss.backward()
            self.optimiser_empathy.step()
            self.lr_scheduler_empathy.step()
            
        print(f'Train loss (distress): {tr_loss_distress / len(train_loader)}')
        print(f'Train loss (empathy): {tr_loss_empathy / len(train_loader)}')

    def fit(self, save_model=False):
        dev_label = pd.read_csv(self.dev_label_file, sep='\t', header=None)
        true_distress = dev_label.iloc[:, 1].tolist()
        true_empathy = dev_label.iloc[:, 0].tolist()
        
        for epoch in range(self.n_epochs):
            print(f'Epoch: {epoch+1}')
            self._training_step()

            (preds_distress, preds_empathy) = self.evaluate(dataloader=self.dev_loader, load_model=False)

            pearson_r_distress = pearsonr(true_distress, preds_distress)
            print(f'Pearson r (distress): {pearson_r_distress}')
            
            pearson_r_empathy = pearsonr(true_empathy, preds_empathy)
            print(f'Pearson r (empathy): {pearson_r_empathy}')
            
            val_loss_empathy = self.loss_fn(torch.tensor(preds_empathy), torch.tensor(true_empathy))
            print('Validation loss (empathy):', val_loss_empathy.item())
            
            if self.early_stopper.early_stop(val_loss_empathy):
                break

            if (pearson_r_empathy > self.best_pearson_r):
                self.best_pearson_r = pearson_r_empathy            
                if save_model:
                    torch.save(self.model.state_dict(), 'EmpathGuRo.pth')
                    print("Saved the model in epoch " + str(epoch+1))
            
            print(f'Best dev set Pearson r (empathy): {self.best_pearson_r}\n')

    def evaluate(self, dataloader, load_model=False):
        if load_model:
            self.model.load_state_dict(torch.load('EmpathGuRo.pth'))
    
        pred_distress = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        pred_empathy = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        self.model_distress.eval()
        self.model_empathy.eval()
    
        with torch.no_grad():
            idx = 0
            for data in dataloader:
                input_ids = data['input_ids'].to(self.device, dtype=torch.long)
                attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
                gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
                education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
                race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
                age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
                income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
                # self.prev_empathy = self.prev_empathy.to(self.device, dtype=torch.float).view(-1, 1)
        
                outputs_distress = self.model_distress(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income,
                    # empathy=self.prev_empathy
                )
                
                batch_size = outputs_distress.shape[0]
                pred_distress[idx:idx+batch_size, :] = outputs_distress
                
                outputs_empathy = self.model_empathy(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income,
                    distress=outputs_distress
                )
        
                pred_empathy[idx:idx+batch_size, :] = outputs_empathy
                
                # self.prev_empathy = outputs_empathy
                
                idx += batch_size
            
        return ([float(k) for k in pred_distress], [float(k) for k in pred_empathy])

## Epoch-level

In [14]:
class Trainer:
    def __init__(self, task, model_distress, model_empathy, lr, n_epochs, train_loader,
                 dev_loader, dev_label_file, device_id=0):
        self.device = get_device(device_id)
        self.task = task
        self.model_distress = model_distress.to(self.device)
        self.model_empathy = model_empathy.to(self.device)
        self.lr = lr
        self.n_epochs = n_epochs
        self.train_loader = train_loader
        self.dev_loader = dev_loader
        self.dev_label_file = dev_label_file
        
        self.loss_fn = nn.MSELoss()
        self.optimiser_distress = torch.optim.AdamW(
            params=self.model_distress.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )
        
        self.optimiser_empathy = torch.optim.AdamW(
            params=self.model_empathy.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-06,
            weight_decay=0.1
        )

        n_training_step = self.n_epochs*len(self.train_loader)
        
        self.lr_scheduler_distress = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_distress,
            num_warmup_steps=0.06*n_training_step,
            num_training_steps=n_training_step
        )
        
        self.lr_scheduler_empathy = get_linear_schedule_with_warmup(
            optimizer=self.optimiser_empathy,
            num_warmup_steps=0.06*n_training_step,
            num_training_steps=n_training_step
        )
        
        self.best_pearson_r = -1.0 # initiliasation
        self.early_stopper = EarlyStopper(patience=3, min_delta=0.01)
        
        assert len(self.task) == 2, 'task must be a list with two elements'
        assert self.task[0] == 'distress', 'First item of task list should be the first guide - distress'
    
    def _freeze_unfreeze(self, model, freeze=False):
        for param in model.parameters():
            param.requires_grad = not freeze # if freeze is required (True): requires_grad is False
            
    def _training_step_distress(self):
        tr_loss_distress = 0.0
        idx = 0
        guide = torch.empty((len(self.train_loader.dataset), 1), device=self.device)
        
        self.model_distress.train()
    
        for data in self.train_loader:
            input_ids = data['input_ids'].to(self.device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
            
            distress = data[self.task[0]].to(self.device, dtype=torch.float).view(-1, 1)
            
            gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
            education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
            race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
            age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
            income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
    
            ### Training distress model
            self._freeze_unfreeze(model_distress, freeze=False)
            self._freeze_unfreeze(model_empathy, freeze=True)
            
            outputs_distress = self.model_distress(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income
            )
            loss = self.loss_fn(outputs_distress, distress)
            tr_loss_distress += loss.item()

            self.optimiser_distress.zero_grad()
            loss.backward()
            self.optimiser_distress.step()
            self.lr_scheduler_distress.step()

            batch_size = outputs_distress.shape[0]
            guide[idx:idx+batch_size, :] = outputs_distress
            idx += batch_size
            
        print(f'Train loss (distress): {tr_loss_distress / len(train_loader)}')
        return guide.detach()

    def _training_step_empathy(self, guide):
        tr_loss_empathy = 0.0
        idx = 0

        self.model_empathy.train()
    
        for data in self.train_loader:
            input_ids = data['input_ids'].to(self.device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
            
            empathy = data[self.task[1]].to(self.device, dtype=torch.float).view(-1, 1)
            
            gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
            education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
            race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
            age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
            income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)

            batch_size = empathy.shape[0]
            batched_guide = guide[idx:idx+batch_size, :]
            idx += batch_size
            
            ### Training empathy model
            self._freeze_unfreeze(model_distress, freeze=True)
            self._freeze_unfreeze(model_empathy, freeze=False)
            
            outputs_empathy = self.model_empathy(
                input_ids=input_ids,                 
                attention_mask=attention_mask,
                gender=gender,
                education=education,
                race=race,
                age=age,
                income=income,
                distress=batched_guide
            )
            loss = self.loss_fn(outputs_empathy, empathy)
            tr_loss_empathy += loss.item()
    
            self.optimiser_empathy.zero_grad()
            loss.backward()
            self.optimiser_empathy.step()
            self.lr_scheduler_empathy.step()
            
        print(f'Train loss (empathy): {tr_loss_empathy / len(train_loader)}')

    def fit(self, save_model=False):
        dev_label = pd.read_csv(self.dev_label_file, sep='\t', header=None)
        true_distress = dev_label.iloc[:, 1].tolist()
        true_empathy = dev_label.iloc[:, 0].tolist()
        
        for epoch in range(self.n_epochs):
            print(f'Epoch: {epoch+1}')
            guide = self._training_step_distress()
            self._training_step_empathy(guide)

            (preds_distress, preds_empathy) = self.evaluate(dataloader=self.dev_loader, load_model=False)

            pearson_r_distress = pearsonr(true_distress, preds_distress)
            print(f'Pearson r (distress): {pearson_r_distress}')
            
            pearson_r_empathy = pearsonr(true_empathy, preds_empathy)
            print(f'Pearson r (empathy): {pearson_r_empathy}')
            
            val_loss_empathy = self.loss_fn(torch.tensor(preds_empathy), torch.tensor(true_empathy))
            print('Validation loss (empathy):', val_loss_empathy.item())
            
            if self.early_stopper.early_stop(val_loss_empathy):
                break

            if (pearson_r_empathy > self.best_pearson_r):
                self.best_pearson_r = pearson_r_empathy            
                if save_model:
                    torch.save(self.model.state_dict(), 'EmpathGuRo.pth')
                    print("Saved the model in epoch " + str(epoch+1))
            
            print(f'Best dev set Pearson r (empathy): {self.best_pearson_r}\n')

    def evaluate(self, dataloader, load_model=False):
        if load_model:
            self.model.load_state_dict(torch.load('EmpathGuRo.pth'))
    
        pred_distress = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        pred_empathy = torch.empty((len(dataloader.dataset), 1), device=self.device) # len(self.dev_loader.dataset) --> # of samples
        self.model_distress.eval()
        self.model_empathy.eval()
    
        with torch.no_grad():
            idx = 0
            for data in dataloader:
                input_ids = data['input_ids'].to(self.device, dtype=torch.long)
                attention_mask = data['attention_mask'].to(self.device, dtype=torch.long)
                gender = data['gender'].to(self.device, dtype=torch.float).view(-1, 1)
                education = data['education'].to(self.device, dtype=torch.float).view(-1, 1)
                race = data['race'].to(self.device, dtype=torch.float).view(-1, 1)
                age = data['age'].to(self.device, dtype=torch.float).view(-1, 1)
                income = data['income'].to(self.device, dtype=torch.float).view(-1, 1)
                # self.prev_empathy = self.prev_empathy.to(self.device, dtype=torch.float).view(-1, 1)
        
                outputs_distress = self.model_distress(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income,
                    # empathy=self.prev_empathy
                )
                
                batch_size = outputs_distress.shape[0]
                pred_distress[idx:idx+batch_size, :] = outputs_distress
                
                outputs_empathy = self.model_empathy(
                    input_ids=input_ids,                 
                    attention_mask=attention_mask,
                    gender=gender,
                    education=education,
                    race=race,
                    age=age,
                    income=income,
                    distress=outputs_distress
                )
        
                pred_empathy[idx:idx+batch_size, :] = outputs_empathy
                
                # self.prev_empathy = outputs_empathy
                
                idx += batch_size
            
        return ([float(k) for k in pred_distress], [float(k) for k in pred_empathy])