In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up automatically without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn

import sys
sys.path.append('..')  # Add parent directory to Python path
from util.data_preprocessing import pad_data_to_30x30, encode_grid, augment_data
from util.model import EnhancedMultiChannelCNN
from util.training import train_model, train_memory
from util.visualization import visualize_loss, visualize_memory, visualize_prediction, visualize_actual
from util.model_utils import initialize_model, initialize_memory
from util.data_loader import prepare_data, read_data

In [3]:
# Choose the compute device, preferring Apple MPS, then CUDA, then the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [34]:
# Collect every JSON file in data/training.
# os.listdir returns entries in arbitrary order, so the list is sorted to make
# index i select the same file on every machine and on every run.
data_dir = '../data/training'
json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))

# Get file name based on index i (position in the sorted list)
i = 12  # Can be changed to select different files
if not -len(json_files) <= i < len(json_files):
    # Fail with a message that says how many files were actually found.
    raise IndexError(
        f"i={i} is out of range: found {len(json_files)} JSON file(s) in {data_dir!r}"
    )
file_name = json_files[i]


In [None]:
# Display the selected file name as this cell's output.
file_name

In [None]:
def train(file_name, model, memory, criterion, epochs=1):
  """Run `epochs` rounds of memory-then-model training and plot the losses.

  Args:
    file_name: Passed to `prepare_data` to build the dataloader.
    model: Forwarded to `train_memory` and `train_model`.
    memory: Forwarded together with `model` to both training helpers.
    criterion: Forwarded to both training helpers (presumably the loss
      function; confirm in util.training).
    epochs: Number of rounds. Each round calls `train_memory` first and
      `train_model` second.

  Returns:
    None. The accumulated loss history is only handed to `visualize_loss`.
  """
  loader = prepare_data(file_name)

  losses = []
  for _ in range(epochs):
    # Memory phase first, then model phase; both contribute to one history.
    for run_phase in (train_memory, train_model):
      losses.extend(run_phase(model, memory, loader, criterion))

  visualize_loss(losses)


# Seed PyTorch's global RNG before anything stochastic runs, so that a fresh
# "Restart & Run All" gives comparable results. Presumably this covers the
# initialisation done in initialize_model/initialize_memory and any shuffling
# in prepare_data; confirm that those helpers draw from torch's global RNG.
SEED = 0
torch.manual_seed(SEED)

# Build a fresh model/criterion pair and a fresh memory, then train both on
# the selected file for 10 rounds.
model, criterion = initialize_model()
memory = initialize_memory()
train(file_name, model, memory, criterion, epochs=10)

In [None]:
# Plot the memory object after training.
visualize_memory(memory)

In [41]:
# Persist the trained model's parameters (its state_dict) to disk.
# NOTE(review): only `model` is written here; `memory` is not saved by this cell.
torch.save(obj=model.state_dict(), f='linear_static_dictionary.pth')

In [None]:
# Load the selected file's contents; the 'test' entry is indexed as a sequence below.
data = read_data(file_name)

# Take the first test example, show it, then show the prediction for it.
x_test = data['test'][0]
visualize_actual(x_test)  # receives the whole example
visualize_prediction(x_test['input'], model, memory)  # receives only the example's 'input' entry
