<a href="https://colab.research.google.com/github/francescostreet/-AN2DL-2025-First-challenge/blob/main/challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Set seed for reproducibility
SEED = 42

# Import necessary libraries
import os

# Set environment variables before importing modules
os.environ['PYTHONHASHSEED'] = str(SEED)
os.environ['MPLCONFIGDIR'] = os.getcwd() + '/configs/'

# Suppress warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

# Import necessary modules
import logging
import random
import numpy as np

# Set seeds for random number generators in NumPy and Python
np.random.seed(SEED)
random.seed(SEED)

# Import PyTorch
import torch
torch.manual_seed(SEED)
from torch import nn
# from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader
logs_dir = "tensorboard"
!pkill -f tensorboard
%load_ext tensorboard
!mkdir -p models

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Import other libraries
import copy
import shutil
from itertools import product
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Configure plot display settings
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

PyTorch version: 2.8.0+cu126
Device: cpu


In [2]:
# Load dataset

X_train = pd.read_csv('pirate_pain_train.csv')
y_train = pd.read_csv('pirate_pain_train_labels.csv')

X_test = pd.read_csv('pirate_pain_test.csv')

In [4]:
# Define column names for the dataset
column_names = ['sample_index', 'time', 'pain_survey_1', 'pain_survey_2', 'pain_survey_3', 'pain_survey_4', 'n_legs', 'n_hands', 'n_eyes', 'joint_00']

# Read the dataset into a DataFrame with specified column names
df = pd.read_csv('pirate_pain_train.csv', header=None, names=column_names)

# Remove rows with any missing values
df.dropna(axis=0, how='any', inplace=True)

# Print the shape of the DataFrame
print(f"DataFrame shape: {df.shape}")

# Display the first 10 rows of the DataFrame
df.head(10)

DataFrame shape: (105761, 10)


Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Unnamed: 8,Unnamed: 9,Unnamed: 10,Unnamed: 11,Unnamed: 12,Unnamed: 13,Unnamed: 14,Unnamed: 15,Unnamed: 16,Unnamed: 17,Unnamed: 18,Unnamed: 19,Unnamed: 20,Unnamed: 21,Unnamed: 22,Unnamed: 23,Unnamed: 24,Unnamed: 25,Unnamed: 26,Unnamed: 27,Unnamed: 28,Unnamed: 29,sample_index,time,pain_survey_1,pain_survey_2,pain_survey_3,pain_survey_4,n_legs,n_hands,n_eyes,joint_00
sample_index,time,pain_survey_1,pain_survey_2,pain_survey_3,pain_survey_4,n_legs,n_hands,n_eyes,joint_00,joint_01,joint_02,joint_03,joint_04,joint_05,joint_06,joint_07,joint_08,joint_09,joint_10,joint_11,joint_12,joint_13,joint_14,joint_15,joint_16,joint_17,joint_18,joint_19,joint_20,joint_21,joint_22,joint_23,joint_24,joint_25,joint_26,joint_27,joint_28,joint_29,joint_30
000,0,2,0,2,1,two,two,two,1.0947052308906111,0.9852806323892689,1.0183017685309896,1.0103846219973498,0.9717169667677581,1.0222626970207267,0.901754688096325,0.9996588076563115,0.7129890259094573,1.0501418311133923,0.5295547622945486,0.4473700071620395,1.091045860897634,0.0,3.0539757619950095e-07,4.7914302751208566e-06,4.402450698163696e-06,0.0,1.6849528758887604e-05,1.2103686199933038e-05,6.140248225873454e-07,3.499557792481375e-06,1.945042470132259e-06,3.999558416844207e-06,1.153299174805662e-05,3.8059302778858464e-06,0.017592085928338518,0.013507849980388193,0.026797650940792724,0.027814594224909676,0.5
000,1,2,2,2,2,two,two,two,1.135183103536129,1.0211747000619738,0.9943433725932543,1.052363882715683,0.9999441405465478,1.012395036980465,0.923340576193706,1.0358501537217646,0.7226847196091615,1.0603133408550205,0.44680991176983303,0.41443183501500275,1.0458617871650786,1.2325046641239027e-05,4.806347286868409e-06,0.0,3.8751338174573535e-06,8.706925123652644e-06,0.0,1.0194172618414565e-05,1.9319783908446927e-06,3.976952213431784e-07,6.7651074426812e-07,6.019626875860215e-06,4.64377443665608e-08,0.0,0.013352218672651914,0.0,0.013376576038862383,0.013715933906977116,0.5
000,2,2,0,2,2,two,two,two,1.0807448109145028,0.9628423568573223,1.0095875703922053,0.9771686446308178,0.9847399141416346,1.019929688699317,0.9765671074880842,1.0727511484853933,0.668043044142936,1.0114104124023555,0.43249864699822294,0.43153480854491194,1.0882210916846609,0.0,3.125992543870344e-07,0.0,1.035689911330332e-06,0.0,9.727752212477372e-06,2.0240181089146394e-06,5.730837816958964e-07,1.5338202764856223e-07,1.6985249698772954e-07,1.4460506893440146e-06,2.424536490939014e-06,2.5135187612327057e-06,0.01622542717601379,0.008110441391572916,0.02409691522589848,0.023105023010497494,0.5
000,3,2,2,2,2,two,two,two,0.9380165318378043,1.081592270878279,0.9980205532045711,0.9872828163155588,0.9241606400066662,1.0026415359849703,0.8309818845474468,1.0807549911684773,0.7020846376287088,1.047223091467062,0.4788060971456342,0.4206648391990636,1.0968324679878347,0.0,3.1618889619588626e-07,0.0,2.5045319889005086e-06,3.4564764188112654e-05,0.0,6.883125683618709e-06,3.677485882417406e-05,1.0068651685544181e-05,5.511079228636987e-07,1.8475966646879033e-06,5.432416362110678e-08,0.0,0.011831559452139423,0.00745007119889528,0.028613141314700982,0.02464820836318195,0.5
000,4,2,2,2,2,two,two,two,1.0901849482346806,1.0321446370205805,1.0087097025989937,0.963657611812324,1.01629057713775,1.0313006090643326,0.9560081721226316,0.9880225833649314,0.7121966006151591,1.0447308388816443,0.45290630069456983,0.4765371874884908,1.1039677783007522,0.0,3.6112449603849616e-07,1.8837532691671258e-06,1.0377896638281432e-06,0.0,2.9797576830971937e-05,2.034571423859658e-06,2.0379387724607624e-05,4.437265668725749e-06,1.7354585896225497e-07,1.5527219085486641e-06,5.8253658712538915e-08,7.044831533755251e-06,0.005360386839374858,0.002531546539790179,0.03302617453410233,0.02532797214906949,0.5
000,5,2,0,2,1,two,two,two,1.1460314612288625,1.0220919949337242,0.975502922491714,0.9703033927783636,1.0924268139421016,1.0267486812509379,0.9030129466841167,1.0195066405570703,0.6462276037756015,1.056162106431079,0.5470641611525502,0.46470935866864155,1.1433193525698642,0.0,3.2334535499303917e-07,1.2188906419334232e-06,1.0385433432397744e-06,0.0,3.263957907020099e-05,1.6006029591871748e-05,2.0607825512760706e-05,1.0731674090588247e-06,1.7538372278427127e-07,2.957340042075233e-07,6.217310659686344e-08,7.475357955435864e-06,0.006150383917344993,0.006444462796853394,0.033101101838941854,0.02376662479054096,0.5
000,6,2,1,2,1,two,two,two,1.0258698643268944,1.0798815430073259,0.9216514328354484,0.8785897648787432,0.9053466204335996,1.0337666275039783,0.9525038424730377,1.0750568609176125,0.7349061424656991,1.0425289374331115,0.4450745132712646,0.47446746051827665,1.06397274043496,3.5154634412833193e-06,3.269120004273811e-07,6.0407808548935105e-06,1.0390993542953717e-06,0.0,0.0,2.04402683720342e-06,2.0380297633275998e-05,1.074799677094579e-06,1.772156251347998e-07,1.9765583948385075e-06,1.5760863171953413e-06,4.637804261006082e-06,0.006495420711872876,0.006420783622156604,0.0318036627067577,0.01905554991156486,0.5
000,7,2,2,2,2,two,two,two,1.038597318187619,1.039254817772802,1.0535315489870498,1.0043352396294278,0.9982444073599172,1.0374798699279513,0.9121288394863325,1.0271858823456954,0.6366906387518414,0.9961253616683727,0.46864065836340035,0.4639635905232247,1.145227357087792,0.0,3.490337128696362e-06,2.037725911193479e-06,3.387811045806682e-06,0.0,3.31279997848281e-05,2.048341213067523e-06,4.7033048566893635e-07,8.82907419971825e-07,1.7904150374561284e-07,2.2105618044813743e-06,1.4857409468201832e-06,0.0,0.015997945858351518,0.0053974619479166975,0.035551569733665286,0.015731614426066978,0.5
000,8,2,2,0,1,two,two,two,0.9842507543998211,1.0056001343947591,0.9654872383182999,1.0444429972244689,0.9655178927836832,1.0470288272496147,0.9848886038610417,1.0285471241369835,0.7358754876592485,0.9862288690598963,0.45405516264728196,0.47852534383965195,1.1010662798328947,3.7632781675330198e-06,3.3402169587319703e-07,3.5154111082626364e-06,1.0396179856738023e-06,0.0,0.0,1.1277781647534198e-05,4.9842766695872264e-05,1.6210552284637883e-06,1.1651607968521788e-06,3.0301636894037034e-07,5.416678423467074e-07,0.0,0.020538582101834763,0.008516687431520533,0.008635017533525661,0.015257299098127262,0.5
