In [1]:
import numpy as np

from kaggle_movie_genres.config import load_config
from kaggle_movie_genres.labelhandler import LabelHandler
from kaggle_movie_genres.featurizer import create_tokenizer_and_embedder
from kaggle_movie_genres.dataloader import create_dataloader
from kaggle_movie_genres.cls_classifier import CLS_Classifier
from kaggle_movie_genres.submission import format_predictions
from kaggle_movie_genres.trainpredict import TrainPredict
import logging
# Configure the root logger: INFO and above, each record rendered as
# "<date time> - <level> - <message>" with second-resolution timestamps.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
# Logger used for messages emitted from this notebook.
logger = logging.getLogger(__name__)
import torch.nn as nn
import torch
import tqdm


### Configuration and label helpers

In [2]:
# config holds all constants, paths and settings
config = load_config()

# Seed NumPy and PyTorch before the later cells build the model, create the
# data loaders and train, so that reruns start from the same random state.
# NOTE(review): if the project already seeds its RNGs from config, this is
# redundant — confirm and drop one of the two.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# label_handler helps to convert labels between different formats
label_handler = LabelHandler(config)



### Tokenizer, embedder and the model

In [3]:
# Build the tokenizer and the embedder from the shared configuration.
tokenizer, embedder = create_tokenizer_and_embedder(config)

# The classifier wraps the embedder; its number of labels is the multi-hot
# length reported by the label handler.
model = CLS_Classifier(
    embedder,
    num_labels=label_handler.get_multi_hot_length(),
    config=config,
)

### Create train / validation sets

In [4]:
# With validation_split=True the call yields two objects, bound here as the
# training set and the validation set.
train_set, validation_set = create_dataloader(
    'data/train.csv',
    tokenizer,
    label_handler,
    config,
    validation_split=True,
)

# With validation_split=False only the first returned object is used; the
# second one is discarded.
test_set, _ = create_dataloader(
    'data/test.csv',
    tokenizer,
    label_handler,
    config,
    validation_split=False,
)

2025-11-16 09:51:34 - INFO - Loaded 8000 records from data/train.csv
2025-11-16 09:51:34 - INFO - Using max token length: 256
2025-11-16 09:51:34 - INFO - Loaded 2000 records from data/test.csv
2025-11-16 09:51:34 - INFO - Using max token length: 256


In [5]:
# Run name is read from the configuration's 'name' entry.
TRAIN_NAME = config['name']

# Hand the model, label handler and all three data sets to the trainer,
# then start training.
trainer = TrainPredict(
    TRAIN_NAME,
    config,
    label_handler,
    model,
    train_set,
    validation_set,
    test_set,
)
trainer.train()

2025-11-16 09:51:37 - INFO - Starting epoch 1/50
2025-11-16 09:52:53 - INFO - Training epoch 1 completed. Train F1 Score: 0.3039, Train Loss: 0.2700
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
2025-11-16 09:53:11 - INFO - Validation epoch 1 completed. Val F1 Score: 0.4234, Val Loss: 0.2162
2025-11-16 09:53:35 - INFO - Test epoch 1 completed. 
2025-11-16 09:53:36 - INFO - Epoch 1 completed. Train Loss: 0.2700, Val Loss: 0.2162
2025-11-16 09:53:36 - INFO - Starting epoch 2/50
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
2025-11-16 09:54:46 - INFO - Training epoch 2 completed. Train F1 Score: 0.4523, Train Loss: 0.2266
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
2025-11-16 09:55:03 - INFO - Validation epoch 2 completed. Val F1 Score: 0.5033, Val Loss: 0.2082
2025-11-16 09:55:25 - INFO - Test epoch 2 completed. 
2025-11-16 09:55:25 - INFO - Epoch 2 completed. Train Loss: 0.2266, Val Loss: 0.20