In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pickle

from src.custom_dataset import CustomDataset
from src.handler import Handler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
images_data_path = './data/archive/images/images'
x_train_file_path = './data/x_train.csv'
y_train_file_path = './data/y_train.csv'
x_val_file_path = './data/x_val.csv'
y_val_file_path = './data/y_val.csv'

In [3]:
batch_size = 1
num_epochs = 1
checkpoint_interval = 200
validation_check_steps = 1

In [4]:
train_dataset = CustomDataset(x_path=x_train_file_path, y_path=y_train_file_path, image_folder_path=images_data_path)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = CustomDataset(x_path=x_val_file_path, y_path=y_val_file_path, image_folder_path=images_data_path)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [None]:
x_num_categories_list=train_dataset.get_x_num_categories_list()
y_num_categories_list=train_dataset.get_y_num_categories_list()
label_encoders = train_dataset.get_label_encoders()

[7, 7, 12, 9, 13, 33, 33, 7, 5, 5, 5]


In [6]:
with open('./data/label_encoders.pkl', 'wb') as f:
	pickle.dump(label_encoders, f)

with open('./data/x_num_categories_list.pkl', 'wb') as f:
	pickle.dump(x_num_categories_list, f)

with open('./data/y_num_categories_list.pkl', 'wb') as f:
	pickle.dump(y_num_categories_list, f)

In [7]:
initial_lr = 1e-3
min_lr = 1e-5
weight_decay_value = 1e-4

criterion = nn.CrossEntropyLoss()
model = Handler(x_num_categories_list=x_num_categories_list, y_num_categories_list=y_num_categories_list)
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=weight_decay_value)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
num_tasks = len(y_num_categories_list)
best_accuracy = 0.0
best_weights = None

# Define intervals for printing and validation checks
print_interval = 1  # Print training stats every 10 batches
validation_check_steps = 100  # Validate every 100 batches

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = [0] * num_tasks  # List to store correct predictions per task
    total_samples = 0
    
    # Training phase
    for batch_idx, (images, tabular_data, labels) in enumerate(train_data_loader):
        # Move data to GPU if available
        images = images.to(device)
        tabular_data = tabular_data.to(device)
        labels = labels.to(device)  # Shape: [batch_size, total_label_length]

        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs_basic = model(images, tabular_data)  # List of outputs per task

        # Partition labels into labels_list using y_num_categories_list
        labels_list = []
        start_idx = 0
        for num_categories in y_num_categories_list:
            end_idx = start_idx + num_categories
            # Extract one-hot labels for this task
            task_labels_one_hot = labels[:, start_idx:end_idx]
            # Convert one-hot labels to class indices
            task_labels_indices = torch.argmax(task_labels_one_hot, dim=1)
            labels_list.append(task_labels_indices)
            start_idx = end_idx
        
        # Compute loss per task
        losses = []
        for output, label in zip(outputs_basic, labels_list):
            loss = criterion(output, label)
            losses.append(loss)
        total_loss = sum(losses)

        # Backward pass and optimize
        total_loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += total_loss.item()
        total_samples += labels.size(0)

        # Calculate accuracy per task
        for i, (output, label) in enumerate(zip(outputs_basic, labels_list)):
            _, predicted = torch.max(output, dim=1)
            correct = (predicted == label).sum().item()
            correct_predictions[i] += correct

        # Print training stats every print_interval batches
        if (batch_idx + 1) % print_interval == 0:
            batch_accuracy = [100 * correct_predictions[i] / total_samples for i in range(num_tasks)]
            avg_accuracy = sum(batch_accuracy) / num_tasks
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_data_loader)}], "
                  f"Loss: {total_loss.item():.4f}, Average Batch Accuracy: {avg_accuracy:.2f}%")

        # Validation evaluation every validation_check_steps steps
        if (batch_idx + 1) % validation_check_steps == 0:
            model.eval()
            val_loss = 0.0
            val_correct_predictions = [0] * num_tasks
            val_total_samples = 0
            with torch.no_grad():
                for val_images, val_tabular_data, val_labels in val_data_loader:
                    val_images = val_images.to(device)
                    val_tabular_data = val_tabular_data.to(device)
                    val_labels = val_labels.to(device)  # Shape: [batch_size, total_label_length]

                    # Forward pass
                    val_outputs_basic = model(val_images, val_tabular_data)

                    # Partition labels
                    val_labels_list = []
                    start_idx = 0
                    for num_categories in y_num_categories_list:
                        end_idx = start_idx + num_categories
                        task_labels_one_hot = val_labels[:, start_idx:end_idx]
                        # Convert one-hot labels to class indices
                        task_labels_indices = torch.argmax(task_labels_one_hot, dim=1)
                        val_labels_list.append(task_labels_indices)
                        start_idx = end_idx

                    # Compute loss per task
                    val_losses = []
                    for val_output, val_label in zip(val_outputs_basic, val_labels_list):
                        loss = criterion(val_output, val_label)
                        val_losses.append(loss)
                    val_total_loss = sum(val_losses)
                    val_loss += val_total_loss.item()
                    val_total_samples += val_labels.size(0)

                    # Calculate accuracy per task
                    for i, (val_output, val_label) in enumerate(zip(val_outputs_basic, val_labels_list)):
                        _, val_predicted = torch.max(val_output, dim=1)
                        correct = (val_predicted == val_label).sum().item()
                        val_correct_predictions[i] += correct

            # Average validation loss and accuracy
            val_loss /= len(val_data_loader)
            val_accuracy = [100 * val_correct_predictions[i] / val_total_samples for i in range(num_tasks)]
            avg_val_accuracy = sum(val_accuracy) / num_tasks
            print(f"Validation Loss: {val_loss:.4f}, Average Validation Accuracy: {avg_val_accuracy:.2f}%")

            # Save the best model weights based on validation accuracy
            if avg_val_accuracy > best_accuracy:
                best_accuracy = avg_val_accuracy
                best_weights = model.state_dict().copy()
                torch.save(best_weights, './models/best_model.pth')
                print("New best model saved as './models/best_model.pth'")
            model.train()  # Return to training mode

    # Epoch-level loss and accuracy
    epoch_loss = running_loss / len(train_data_loader)
    epoch_accuracy = [100 * correct_predictions[i] / total_samples for i in range(num_tasks)]
    avg_epoch_accuracy = sum(epoch_accuracy) / num_tasks
    print(f"Epoch {epoch+1}/{num_epochs} completed: Loss: {epoch_loss:.4f}, Average Accuracy: {avg_epoch_accuracy:.2f}%")
    
    # Update the scheduler for learning rate decay
    scheduler.step()

# Save the final model weights
torch.save(model.state_dict(), './models/final_model.pth')
print("Final model saved as './models/final_model.pth'")

# Ensure best weights are also saved
if best_weights is not None:
    torch.save(best_weights, './models/best_model.pth')
    print("Best model saved as './models/best_model.pth'")

Epoch [1/1], Batch [1/57626], Loss: 21.9851, Average Batch Accuracy: 45.45%
Epoch [1/1], Batch [2/57626], Loss: 20.9844, Average Batch Accuracy: 50.00%
Epoch [1/1], Batch [3/57626], Loss: 21.9843, Average Batch Accuracy: 48.48%
Epoch [1/1], Batch [4/57626], Loss: 19.9851, Average Batch Accuracy: 52.27%
Epoch [1/1], Batch [5/57626], Loss: 20.9846, Average Batch Accuracy: 52.73%
Epoch [1/1], Batch [6/57626], Loss: 20.9787, Average Batch Accuracy: 53.03%
Epoch [1/1], Batch [7/57626], Loss: 20.9826, Average Batch Accuracy: 53.25%
Epoch [1/1], Batch [8/57626], Loss: 21.9506, Average Batch Accuracy: 52.27%
Epoch [1/1], Batch [9/57626], Loss: 19.9636, Average Batch Accuracy: 53.54%
Epoch [1/1], Batch [10/57626], Loss: 20.4308, Average Batch Accuracy: 54.55%
Epoch [1/1], Batch [11/57626], Loss: 18.9683, Average Batch Accuracy: 56.20%
Epoch [1/1], Batch [12/57626], Loss: 20.9816, Average Batch Accuracy: 56.06%
Epoch [1/1], Batch [13/57626], Loss: 19.9797, Average Batch Accuracy: 56.64%


KeyboardInterrupt: 