In [None]:
import copy
import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from src.custom_dataset import CustomDataset
from src.handler import Handler
from src.handler import Handler

In [2]:
# Input locations: the image folder plus the pre-split feature (x) / label (y) CSVs.
images_data_path = './data/archive/images/images'
x_train_file_path, y_train_file_path = './data/x_train.csv', './data/y_train.csv'
x_val_file_path, y_val_file_path = './data/x_val.csv', './data/y_val.csv'

In [3]:
# Training hyper-parameters
batch_size = 1024
num_epochs = 2
checkpoint_interval = 200  # NOTE(review): not referenced by any active code visible in this notebook
validation_check_steps = 1  # run a full validation pass every N training batches (1 = after every batch)

# Seed torch's default RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [4]:
# Training batches are reshuffled on every pass over the loader.
train_dataset = CustomDataset(x_path=x_train_file_path, y_path=y_train_file_path, image_folder_path=images_data_path)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation loss/accuracy do not depend on sample order. A shuffling loader would
# draw from torch's default RNG on every validation pass and perturb the training
# shuffle sequence, so validation is iterated in a fixed order.
val_dataset = CustomDataset(x_path=x_val_file_path, y_path=y_val_file_path, image_folder_path=images_data_path)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Category counts and label encoders are taken from the training split only.
x_num_categories_list, y_num_categories_list, label_encoders = (
	train_dataset.get_x_num_categories_list(),
	train_dataset.get_y_num_categories_list(),
	train_dataset.get_label_encoders(),
)

In [None]:
# Pickle the training-split encoders and category counts into ./data, one file each.
for pkl_path, payload in (
	('./data/label_encoders.pkl', label_encoders),
	('./data/x_num_categories_list.pkl', x_num_categories_list),
	('./data/y_num_categories_list.pkl', y_num_categories_list),
):
	with open(pkl_path, 'wb') as f:
		pickle.dump(payload, f)

In [5]:
# Optimiser settings: starting LR, cosine-annealing floor, and AdamW weight decay.
initial_lr, min_lr, weight_decay_value = 1e-3, 1e-5, 1e-4

criterion = nn.CrossEntropyLoss()
model = Handler(
	x_num_categories_list=x_num_categories_list,
	y_num_categories_list=y_num_categories_list,
)
optimizer = optim.AdamW(
	model.parameters(),
	lr=initial_lr,
	weight_decay=weight_decay_value,
)
# Anneal the LR from initial_lr towards min_lr over T_max=num_epochs scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
	optimizer,
	T_max=num_epochs,
	eta_min=min_lr,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# torch.save does not create parent directories.
os.makedirs('./models', exist_ok=True)


def _to_class_indices(labels):
	"""Return targets suitable for CrossEntropyLoss.

	2-D labels with more than one column (e.g. one-hot rows) are reduced with
	argmax over dim 1; any other labels are returned unchanged.
	"""
	if labels.dim() == 2 and labels.size(1) > 1:
		return torch.argmax(labels, dim=1)
	return labels


def _forward_logits(net, images, tabular_data):
	"""Run `net` and concatenate the tensors it returns along dim 1."""
	# NOTE(review): the model returns a sequence of tensors (presumably one per entry
	# of y_num_categories_list) that is scored as ONE softmax by a single
	# CrossEntropyLoss. With several heads this is only meaningful if the label index
	# refers to the concatenated class space -- confirm against the dataset.
	return torch.cat(net(images, tabular_data), dim=1)


def _evaluate(net, data_loader, loss_fn, eval_device):
	"""Evaluate `net` over `data_loader` without tracking gradients.

	Returns:
		(mean per-sample loss, accuracy in percent); both are NaN when the loader
		yields no samples.

	Leaves `net` in eval mode; the caller is responsible for switching back.
	"""
	net.eval()
	loss_sum = 0.0
	num_correct = 0
	num_samples = 0
	with torch.no_grad():
		for images, tabular_data, labels in data_loader:
			images, tabular_data = images.to(eval_device), tabular_data.to(eval_device)
			labels = _to_class_indices(labels.to(eval_device))
			outputs = _forward_logits(net, images, tabular_data)
			n = labels.size(0)
			# Weight by batch size so a smaller final batch is not over-counted
			# (assumes reduction='mean', the CrossEntropyLoss default).
			loss_sum += loss_fn(outputs, labels).item() * n
			num_correct += (outputs.argmax(dim=1) == labels).sum().item()
			num_samples += n
	if num_samples == 0:
		return float('nan'), float('nan')
	return loss_sum / num_samples, 100 * num_correct / num_samples


# Training loop with periodic validation (weight decay is applied by AdamW).
best_accuracy = 0.0
best_weights = None

for epoch in range(num_epochs):
	model.train()
	running_loss = 0.0
	correct_predictions = 0
	total_samples = 0
	num_batches = len(train_data_loader)

	# Training phase
	for batch_idx, (images, tabular_data, labels) in enumerate(train_data_loader):
		images, tabular_data = images.to(device), tabular_data.to(device)
		labels = _to_class_indices(labels.to(device))

		optimizer.zero_grad()
		outputs = _forward_logits(model, images, tabular_data)
		loss = criterion(outputs, labels)
		loss.backward()
		optimizer.step()

		batch_loss = loss.item()
		n = labels.size(0)
		running_loss += batch_loss * n  # per-sample weighting, same as _evaluate
		correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()
		total_samples += n

		# Accuracy accumulated over the epoch so far, not just this batch.
		running_accuracy = 100 * correct_predictions / total_samples
		print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{num_batches}], "
				f"Loss: {batch_loss:.4f}, Running Accuracy: {running_accuracy:.2f}%")

		# Full validation pass every validation_check_steps training batches.
		if (batch_idx + 1) % validation_check_steps == 0:
			val_loss, val_accuracy = _evaluate(model, val_data_loader, criterion, device)
			print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

			# Save the best model weights based on validation accuracy
			if val_accuracy > best_accuracy:
				best_accuracy = val_accuracy
				# state_dict() holds references to the live parameter tensors, so a
				# shallow dict copy would keep tracking later optimizer updates.
				# Deep-copy to freeze a real snapshot of the best weights.
				best_weights = copy.deepcopy(model.state_dict())
				torch.save(best_weights, './models/best_model.pth')
				print("New best model saved as './models/best_model.pth'")
			model.train()  # _evaluate leaves the model in eval mode

	# Epoch-level loss and accuracy (skipped if the loader produced no samples)
	if total_samples:
		epoch_loss = running_loss / total_samples
		epoch_accuracy = 100 * correct_predictions / total_samples
		print(f"Epoch {epoch+1}/{num_epochs} completed: Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

	# Step the learning-rate scheduler once per epoch.
	scheduler.step()

# Save the final model weights
torch.save(model.state_dict(), './models/final_model.pth')
print("Final model saved as './models/final_model.pth'")

# Re-save the best snapshot (already written whenever it improved; kept as a safeguard)
if best_weights is not None:
	torch.save(best_weights, './models/best_model.pth')
	print("Best model saved as './models/best_model.pth'")