<a href="https://colab.research.google.com/github/ncerutti/colabs/blob/main/FSPD_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers torch wandb tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m40.2 MB/s[0m eta [36m0:00:00[0m
Collecting wandb
  Downloading wandb-0.14.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m29.3 MB/s[0m eta [36m0:00:00[0m
Collecting pathto

In [2]:
from google.colab import drive
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils.class_weight import compute_sample_weight
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm.notebook import tqdm
import wandb

torch.__version__

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("All good")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
    print("No GPU!!!")

No GPU!!!


In [3]:
drive.mount('/content/drive')
checkpoint_dir = "/content/drive/MyDrive/FSPD/ModelCheckpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

Mounted at /content/drive


In [4]:
# epoch_to_load = 9  # Change this to the epoch number you want to load
# checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch_to_load}_classifier.pth")

# loaded_classifier = SimpleNNClassifier(input_size, hidden_size, output_size).to(device)
# loaded_classifier.load_state_dict(torch.load(checkpoint_path))
# loaded_classifier.eval()  # Set the model to evaluation mode

In [5]:
wandb.login()
wandb.init(project="FSPD", config={"architecture": "SimpleNNClassifier", "epochs": 10, "batch_size": 16, "learning_rate": "OCP", "hidden_size": 512})

[34m[1mwandb[0m: Currently logged in as: [33mncerutti[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Create clean_fspd function. This function will take in the fspd dataframe and return a cleaned version of it.

def clean_fspd(fspd_f):
    """This function takes in the fspd dataframe and returns a cleaned version of it.
    """
    # Create a list of columns that are not needed
    drop_cols = ["lever", "itype", "source1link", "framework", "iso", "region_wb", "income_group2", "defn", "initialdate", "inclusion", "envitarget", "diethealth"]
    # Drop the columns in drop_cols from fspd_f
    fspd_f = fspd_f.drop(columns=drop_cols)

    # Replace the values in "covid_mentioned" with 0 if they are "nan"
    fspd_f["covid_mentioned"] = fspd_f["covid_mentioned"].replace(np.nan, 0)

    # Replace the values in "targeted" with 0 if the are "" o "N" and with 1 if they are "Y"
    fspd_f["targeted"] = fspd_f["targeted"].replace("", 0)
    fspd_f["targeted"] = fspd_f["targeted"].replace("N", 0)
    fspd_f["targeted"] = fspd_f["targeted"].replace("Y", 1)

    # replace "policy_code" with 0 if it is empty
    fspd_f["policy_code"] = fspd_f["policy_code"].replace(np.nan, 0)
    fspd_f["y_end"] = fspd_f["y_end"].replace(np.nan, 0)
    fspd_f["y_start"] = fspd_f["y_start"].replace("", 0)
    
    return fspd_f



def encode_fspd(fspd_f):
    """This function takes in the fspd dataframe and returns the dataframe with one-hot encoding of a list of variables.
    """
    to_encode = ["country", "db", "policy_code", "y_start", "y_end", "income_group", "fsd_group"]
    fspd_f = pd.get_dummies(fspd_f, columns=to_encode)
    return fspd_f


def get_non_text_features(batch_data, non_text_features):
    batch_indices = batch_data["index"].numpy()
    batch_non_text_features = non_text_features.loc[batch_indices]
    batch_non_text_features_tensor = torch.tensor(batch_non_text_features.values, dtype=torch.float32)
    return batch_non_text_features_tensor


def compute_class_weights(y):
    unique_classes = np.unique(y)
    class_weights = compute_class_weight('balanced', classes=unique_classes, y=y)
    return dict(zip(unique_classes, class_weights))


def get_sample_weights(y, class_weights):
    return np.array([class_weights[cls] for cls in y])

In [5]:
class FSPData(Dataset):
    def __init__(self, data, target_segment):
        self.data = data
        #self.target_lever = target_lever
        self.target_segment = target_segment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        return {
            "index": torch.tensor(idx, dtype=torch.long),  # Add this line
            "policydecision_details": item["policydecision_details_tokens"],
            "policy_description": item["policy_description_tokens"],
            "contextoradditionalinfo": item["contextoradditionalinfo_tokens"],
            "source1name": item["source1name_tokens"],
            # Include other features as needed
            #"lever": torch.tensor(self.target_lever[idx], dtype=torch.long),
            "segment": torch.tensor(self.target_segment[idx], dtype=torch.long)
        }


In [6]:
class SimpleNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [7]:
# Read in data from .dta
#fspd_f = pd.read_stata("/content/FSPD.dta", index_col="id")
fspd_f = pd.read_stata("/content/drive/MyDrive/FSPD/FSPD.dta", index_col="id")

In [15]:
fspd = clean_fspd(fspd_f)
encfspd = encode_fspd(fspd)

In [16]:
from transformers import DistilBertTokenizer, DistilBertModel 

# Initialize DistilBERT model and tokenizer
pretrained_model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(pretrained_model_name)
bert_model = DistilBertModel.from_pretrained(pretrained_model_name)

bert_model = bert_model.to(device)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
#from transformers import BertTokenizer, BertModel

## Initialize BERT model and tokenizer
#pretrained_model_name = "bert-base-uncased"
#tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
#bert_model = BertModel.from_pretrained(pretrained_model_name)

In [18]:
encfspd["policydecision_details_tokens"] = encfspd["policydecision_details"].apply(lambda x: tokenizer.encode(x, truncation=True, max_length=128))
encfspd["policy_description_tokens"] = encfspd["policy_description"].apply(lambda x: tokenizer.encode(x, truncation=True, max_length=128))
encfspd["contextoradditionalinfo_tokens"] = encfspd["contextoradditionalinfo"].apply(lambda x: tokenizer.encode(x, truncation=True, max_length=96))
encfspd["source1name_tokens"] = encfspd["source1name"].apply(lambda x: tokenizer.encode(x, truncation=True, max_length=8))

In [19]:
max_length = max(encfspd[["policydecision_details_tokens", "policy_description_tokens", "contextoradditionalinfo_tokens", "source1name"]].applymap(len).max())

encfspd["policydecision_details_tokens"] = encfspd["policydecision_details_tokens"].apply(lambda x: x + [0] * (max_length - len(x)))
encfspd["policy_description_tokens"] = encfspd["policy_description_tokens"].apply(lambda x: x + [0] * (max_length - len(x)))
encfspd["contextoradditionalinfo_tokens"] = encfspd["contextoradditionalinfo_tokens"].apply(lambda x: x + [0] * (max_length - len(x)))
encfspd["source1name_tokens"] = encfspd["source1name_tokens"].apply(lambda x: x + [0] * (max_length - len(x)))

In [20]:
encfspd["policydecision_details_tokens"] = encfspd["policydecision_details_tokens"].apply(lambda x: torch.tensor(x))
encfspd["policy_description_tokens"] = encfspd["policy_description_tokens"].apply(lambda x: torch.tensor(x))
encfspd["contextoradditionalinfo_tokens"] = encfspd["contextoradditionalinfo_tokens"].apply(lambda x: torch.tensor(x))
encfspd["source1name_tokens"] = encfspd["source1name_tokens"].apply(lambda x: torch.tensor(x))

In [21]:
# Create LabelEncoder instances for lever and segment
#lever_encoder = LabelEncoder()
segment_encoder = LabelEncoder()

# Fit the encoders on the respective target labels and transform them
#encfspd["lever"] = lever_encoder.fit_transform(encfspd["lever"])
encfspd["segment"] = segment_encoder.fit_transform(encfspd["segment"])

# Extract lever and segment labels from the encfspd DataFrame
#lever_labels = encfspd["lever"].values
segment_labels = encfspd["segment"].values

In [22]:
np.unique(segment_labels)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [23]:
encfspd["segment"]

id
1.0         1
2.0        10
3.0         8
4.0        10
5.0        10
           ..
15588.0     3
15589.0     1
15590.0    10
15591.0     1
15592.0     4
Name: segment, Length: 15592, dtype: int64

In [14]:
train_data

Unnamed: 0_level_0,segment,policydecision_details,policy_description,covid_mentioned,source1name,contextoradditionalinfo,targeted,is_eu_country,is_eu_wide,country_Afghanistan,...,income_group_High income,income_group_Low income,income_group_Lower middle income,income_group_Upper middle income,fsd_group_,fsd_group_Emerging and diversifying,fsd_group_Industrialized and consolidated,fsd_group_Informal and expanding,fsd_group_Modernizing and formalizing,fsd_group_Rural and traditional
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
5461.0,10,The national redd+ strategy (nrs) contributes ...,National redd+ strategy draft (nrs 2016-2030),0.0,"Ministry of environment, forest and climate ch...",,0,0.0,0.0,0,...,0,1,0,0,0,0,0,0,0,1
2894.0,3,Se anuncia la apertura de los mercados de chin...,"Sanitary, phytosanitary and technical standard...",0.0,Ministerio de agricultura,,0,0.0,0.0,0,...,1,0,0,0,0,0,0,0,1,0
11463.0,3,Covid-19: the european union has adopted tempo...,Technical barriers to trade,1.0,Official journal of the european union,Information on this policy has been provided b...,1,1.0,1.0,0,...,1,0,0,0,0,0,0,0,1,0
15336.0,3,The council of ministers of yemen banned impo...,Import ban,0.0,Ministry of trade and industry,,0,0.0,0.0,0,...,0,1,0,0,0,0,0,0,0,1
8349.0,3,Covid-19: the european commission (ec) elimina...,Import tariff,1.0,Usda report,The reintroduction of duties in april 2020 was...,0,1.0,1.0,0,...,1,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5192.0,5,Egypt has set a fixed price for unsubsidised b...,Price control,0.0,Reuters,,0,0.0,0.0,0,...,0,0,1,0,0,0,0,1,0,0
13419.0,3,Despite the export ban measures introduced in ...,Other export promotion measures,0.0,Fapda country brief,,0,0.0,0.0,0,...,0,1,0,0,1,0,0,0,0,0
5391.0,10,The united nations development assistance fram...,United nations development assistance framewor...,0.0,Ministry of finance and economic cooperation,,0,0.0,0.0,0,...,0,1,0,0,0,0,0,0,0,1
861.0,9,The government of australia issued the conserv...,Unspecified land policy measure,0.0,Chief parliamentary counsel,,0,0.0,0.0,0,...,1,0,0,0,0,0,1,0,0,0


In [24]:
## Careful: no test!
# train_data, val_data, train_segment, val_segment = train_test_split(encfspd, segment_labels, test_size=0.2, random_state=42)

# With test
train_data, temp_data, train_segment, temp_segment = train_test_split(encfspd, segment_labels, test_size=0.3, random_state=42)
val_data, test_data, val_segment, test_segment = train_test_split(temp_data, temp_segment, test_size=0.5, random_state=42)

# Compute class weights
class_weights = compute_class_weights(train_segment)

# Compute sample weights
train_sample_weights = get_sample_weights(train_segment, class_weights)

# Create WeightedRandomSampler
weighted_sampler = WeightedRandomSampler(train_sample_weights, num_samples=len(train_sample_weights), replacement=True)

train_dataset = FSPData(train_data, train_segment)
val_dataset = FSPData(val_data, val_segment)
test_dataset = FSPData(test_data, test_segment)

## Final: train on whole dataset
# train_dataset = FSPData(encfspd, segment_labels)

batch_size = 16
#train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=weighted_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [19]:
# Create non-text-feature dataframe. It contains column 6 and then from 8 to the end

slice1 = encfspd.iloc[:, 3]
slice2 = encfspd.iloc[:, 6:]
slice3 = encfspd.iloc[:, 10:411]

non_text_features = pd.concat([slice1, slice2, slice3], axis=1).reset_index(drop=True)

In [20]:
print(non_text_features.select_dtypes(include=['object']).columns)


Index(['policydecision_details_tokens', 'policy_description_tokens',
       'contextoradditionalinfo_tokens', 'source1name_tokens'],
      dtype='object')


In [21]:
non_text_features = non_text_features.drop(non_text_features.select_dtypes(include=['object']).columns, axis=1)


In [22]:
print(non_text_features.select_dtypes(include=['object']).columns)

Index([], dtype='object')


In [23]:
print(len(segment_labels))

15592


In [24]:
np.unique(segment_labels)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [25]:
# Classifier N.1: Simple NN

input_size = 3879
hidden_size = 512 
output_size = len(np.unique(segment_labels))

classifier = SimpleNNClassifier(input_size, hidden_size, output_size).to(device)
criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(classifier.parameters(), lr=0.001)

runs = 4
epochs_per_run = 5
epochs = runs * epochs_per_run
learning_rate = 0.01

optimizer = optim.AdamW(classifier.parameters(), lr=learning_rate)
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader))

In [None]:
for epoch in range(epochs):
    current_run = epoch // epochs_per_run + 1
    current_epoch = epoch % epochs_per_run + 1
    train_running_loss = 0.0
    train_running_corrects = torch.tensor(0, device=device, dtype=torch.float)
    val_running_loss = 0.0
    val_running_corrects = torch.tensor(0, device=device, dtype=torch.float)

    # Training loop
    classifier.train()
    for batch_idx, batch_data in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Training]")):
        
        batch_data["policydecision_details"] = batch_data["policydecision_details"].to(device)
        batch_data["policy_description"] = batch_data["policy_description"].to(device)
        batch_data["contextoradditionalinfo"] = batch_data["contextoradditionalinfo"].to(device)
        batch_data["source1name"] = batch_data["source1name"].to(device)

        # Obtain embeddings for each text feature
        policydecision_details_embeddings = bert_model(batch_data["policydecision_details"])
        policy_description_embeddings = bert_model(batch_data["policy_description"])
        contextoradditionalinfo_embeddings = bert_model(batch_data["contextoradditionalinfo"])
        source1name_embeddings = bert_model(batch_data["source1name"])

        # Concatenate embeddings
        combined_embeddings = torch.cat((policydecision_details_embeddings.last_hidden_state[:, 0, :],
                                        policy_description_embeddings.last_hidden_state[:, 0, :],
                                        contextoradditionalinfo_embeddings.last_hidden_state[:, 0, :],
                                        source1name_embeddings.last_hidden_state[:, 0, :]), dim=1)
        
        # Concatenate non-text features
        batch_non_text_features = get_non_text_features(batch_data, non_text_features)
        batch_non_text_features = batch_non_text_features.to(device)
        combined_features = torch.cat((combined_embeddings, batch_non_text_features), dim=1)

        # Make sure to zero the gradients before every training step
        optimizer.zero_grad()
        
        # Forward pass through the classifier
        logits = classifier(combined_features)
        
        # Calculate the loss
        loss = criterion(logits, batch_data["segment"].to(device))
        
        # Backward pass
        loss.backward()
        
        # Update the weights
        optimizer.step()
        # Accumulate the loss and correct predictions for the current batch
        train_running_loss += loss.item() * batch_data["policydecision_details"].size(0)
        _, preds = torch.max(logits, 1)
        train_running_corrects += torch.sum(preds == batch_data["segment"].to(device))

    scheduler.step()
    train_epoch_loss = train_running_loss / len(train_loader.dataset)
    train_epoch_acc = train_running_corrects.double() / len(train_loader.dataset)

    # Validation loop
    classifier.eval()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validation]")):

            batch_data["policydecision_details"] = batch_data["policydecision_details"].to(device)
            batch_data["policy_description"] = batch_data["policy_description"].to(device)
            batch_data["contextoradditionalinfo"] = batch_data["contextoradditionalinfo"].to(device)
            batch_data["source1name"] = batch_data["source1name"].to(device)

            # Obtain embeddings for each text feature
            policydecision_details_embeddings = bert_model(batch_data["policydecision_details"])
            policy_description_embeddings = bert_model(batch_data["policy_description"])
            contextoradditionalinfo_embeddings = bert_model(batch_data["contextoradditionalinfo"])
            source1name_embeddings = bert_model(batch_data["source1name"])

            # Concatenate embeddings
            combined_embeddings = torch.cat((policydecision_details_embeddings.last_hidden_state[:, 0, :],
                                            policy_description_embeddings.last_hidden_state[:, 0, :],
                                            contextoradditionalinfo_embeddings.last_hidden_state[:, 0, :],
                                            source1name_embeddings.last_hidden_state[:, 0, :]), dim=1)
            
            # Concatenate non-text features
            batch_non_text_features = get_non_text_features(batch_data, non_text_features)
            batch_non_text_features = batch_non_text_features.to(device)
            combined_features = torch.cat((combined_embeddings, batch_non_text_features), dim=1)

            # Forward pass through the classifier
            logits = classifier(combined_features)

            # Calculate the loss
            loss = criterion(logits, batch_data["segment"].to(device))

            # Accumulate the loss and correct predictions for the current batch
            val_running_loss += loss.item() * batch_data["policydecision_details"].size(0)
            _, preds = torch.max(logits, 1)
            val_running_corrects += torch.sum(preds == batch_data["segment"].to(device))

    val_epoch_loss = val_running_loss / len(val_loader.dataset)
    val_epoch_acc = val_running_corrects.double() / len(val_loader.dataset)

    # Log metrics to wandb
    wandb.log({"train_loss": train_epoch_loss, "train_acc": train_epoch_acc,
               "val_loss": val_epoch_loss, "val_acc": val_epoch_acc})

    # Print metrics to console
    print(f"Run {current_run}/{runs}, Epoch {current_epoch}/{epochs_per_run}")
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_acc:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}")
    checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch+1}_classifier_redux.pth")
    torch.save(classifier.state_dict(), checkpoint_path)

wandb.finish()


Epoch 1/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 1/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 1/4, Epoch 1/5
Train Loss: 1.6992, Train Acc: 0.4177
Val Loss: 1.5125, Val Acc: 0.4489


Epoch 2/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 2/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 1/4, Epoch 2/5
Train Loss: 1.0275, Train Acc: 0.6891
Val Loss: 1.1618, Val Acc: 0.6127


Epoch 3/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 3/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 1/4, Epoch 3/5
Train Loss: 0.6978, Train Acc: 0.7990
Val Loss: 0.9470, Val Acc: 0.6917


Epoch 4/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 4/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 1/4, Epoch 4/5
Train Loss: 0.4886, Train Acc: 0.8583
Val Loss: 0.9015, Val Acc: 0.6870


Epoch 5/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 5/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 1/4, Epoch 5/5
Train Loss: 0.3860, Train Acc: 0.8838
Val Loss: 0.7257, Val Acc: 0.7332


Epoch 6/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 6/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 2/4, Epoch 1/5
Train Loss: 0.3204, Train Acc: 0.9044
Val Loss: 0.5239, Val Acc: 0.8286


Epoch 7/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 7/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 2/4, Epoch 2/5
Train Loss: 0.2759, Train Acc: 0.9142
Val Loss: 0.4400, Val Acc: 0.8611


Epoch 8/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 8/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 2/4, Epoch 3/5
Train Loss: 0.2489, Train Acc: 0.9216
Val Loss: 0.5548, Val Acc: 0.7995


Epoch 9/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 9/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 2/4, Epoch 4/5
Train Loss: 0.2186, Train Acc: 0.9290
Val Loss: 0.4366, Val Acc: 0.8542


Epoch 10/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 10/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 2/4, Epoch 5/5
Train Loss: 0.1992, Train Acc: 0.9372
Val Loss: 0.4278, Val Acc: 0.8615


Epoch 11/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 11/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 3/4, Epoch 1/5
Train Loss: 0.1793, Train Acc: 0.9437
Val Loss: 0.4439, Val Acc: 0.8525


Epoch 12/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 12/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 3/4, Epoch 2/5
Train Loss: 0.1627, Train Acc: 0.9469
Val Loss: 0.4245, Val Acc: 0.8593


Epoch 13/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 13/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 3/4, Epoch 3/5
Train Loss: 0.1481, Train Acc: 0.9525
Val Loss: 0.4054, Val Acc: 0.8726


Epoch 14/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 14/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 3/4, Epoch 4/5
Train Loss: 0.1477, Train Acc: 0.9530
Val Loss: 0.3549, Val Acc: 0.8918


Epoch 15/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 15/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 3/4, Epoch 5/5
Train Loss: 0.1530, Train Acc: 0.9502
Val Loss: 0.3256, Val Acc: 0.8948


Epoch 16/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 16/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 4/4, Epoch 1/5
Train Loss: 0.1376, Train Acc: 0.9575
Val Loss: 0.3615, Val Acc: 0.8794


Epoch 17/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 17/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 4/4, Epoch 2/5
Train Loss: 0.1161, Train Acc: 0.9621
Val Loss: 0.3828, Val Acc: 0.8816


Epoch 18/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

Epoch 18/20 [Validation]:   0%|          | 0/147 [00:00<?, ?it/s]

Run 4/4, Epoch 3/5
Train Loss: 0.1258, Train Acc: 0.9568
Val Loss: 0.3469, Val Acc: 0.8940


Epoch 19/20 [Training]:   0%|          | 0/683 [00:00<?, ?it/s]

In [None]:
## TEST

# Evaluate the model on the test dataset
classifier.eval()

test_running_loss = 0.0
test_running_corrects = torch.tensor(0, device=device, dtype=torch.float)

with torch.no_grad():
    for batch_idx, batch_data in enumerate(tqdm(test_loader, desc="Testing")):
        batch_data["policydecision_details"] = batch_data["policydecision_details"].to(device)
        batch_data["policy_description"] = batch_data["policy_description"].to(device)
        batch_data["contextoradditionalinfo"] = batch_data["contextoradditionalinfo"].to(device)
        batch_data["source1name"] = batch_data["source1name"].to(device)

        # Obtain embeddings for each text feature
        policydecision_details_embeddings = bert_model(batch_data["policydecision_details"])
        policy_description_embeddings = bert_model(batch_data["policy_description"])
        contextoradditionalinfo_embeddings = bert_model(batch_data["contextoradditionalinfo"])
        source1name_embeddings = bert_model(batch_data["source1name"])

        # Concatenate embeddings
        combined_embeddings = torch.cat((policydecision_details_embeddings.last_hidden_state[:, 0, :],
                                         policy_description_embeddings.last_hidden_state[:, 0, :],
                                         contextoradditionalinfo_embeddings.last_hidden_state[:, 0, :],
                                         source1name_embeddings.last_hidden_state[:, 0, :]), dim=1)

        # Concatenate non-text features
        batch_non_text_features = get_non_text_features(batch_data, non_text_features)
        batch_non_text_features = batch_non_text_features.to(device)
        combined_features = torch.cat((combined_embeddings, batch_non_text_features), dim=1)

        # Forward pass through the classifier
        logits = classifier(combined_features)

        # Calculate the loss
        loss = criterion(logits, batch_data["segment"].to(device))

        # Accumulate the loss and correct predictions for the current batch
        test_running_loss += loss.item() * batch_data["policydecision_details"].size(0)
        _, preds = torch.max(logits, 1)
        test_running_corrects += torch.sum(preds == batch_data["segment"].to(device))

test_loss = test_running_loss / len(test_loader.dataset)
test_acc = test_running_corrects.double() / len(test_loader.dataset)

print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
