In [7]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import random
import glob
import pickle
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models import CNN_LSTM

In [8]:
# Read one preprocessed sample array from disk so its layout can be inspected.
sample_data = np.load(file='./data/PROCESSED_III/2018_13_115.npy')

# Report the dimensions of the loaded array.
print("Shape of sample data:", np.shape(sample_data))

Shape of sample data: (38, 1, 128, 9)


In [9]:
# Define generator function
def generator(IDs, yields, batch_size, cutoff=None,
              data_dir='./data/PROCESSED_III/', max_attempts=100):
    """Endlessly yield random mini-batches of (features, targets) tensors.

    Each batch slot is filled by drawing a random ID from ``IDs`` (with
    replacement) and loading ``<data_dir>/<ID>.npy``. A sample that cannot be
    loaded or that contains NaN values is skipped and another ID is drawn, so
    a yielded batch never contains zero-filled placeholder rows.

    Args:
        IDs: Sequence of sample identifiers (file names without ``.npy``).
        yields: Mapping from sample ID to its scalar target value.
        batch_size: Number of samples per yielded batch.
        cutoff: If given, only the first ``cutoff`` entries along the first
            axis of each sample are used; otherwise 38 entries are expected.
        data_dir: Directory holding the ``.npy`` files. Defaults to the same
            relative location the notebook loads its sample file from.
        max_attempts: Maximum number of consecutive unusable samples tolerated
            for a single batch slot before giving up.

    Yields:
        Tuple ``(features, targets)`` of ``torch.float32`` tensors with shapes
        ``(batch_size, T, 1, 128, 9)`` and ``(batch_size,)``, where ``T`` is
        ``cutoff`` if given and 38 otherwise.

    Raises:
        IndexError: If ``IDs`` is empty (raised by ``random.choice``).
        KeyError: If a drawn ID is missing from ``yields``.
        RuntimeError: If ``max_attempts`` consecutive draws produce no usable
            sample (e.g. wrong ``data_dir``).
        ValueError: If a loaded sample does not fit the batch buffer shape.
    """
    def load_data(ID):
        """Return the usable array for ``ID``, or None if it must be skipped."""
        path = os.path.join(data_dir, ID + '.npy')
        try:
            data = np.load(path)
        except (OSError, ValueError, EOFError):
            # Missing, unreadable or malformed file: treat the sample as unusable.
            return None
        if np.isnan(data).any():
            print('Data contains NaN values:', ID)
            return None
        return data if cutoff is None else data[:cutoff, :, :, :]

    time_steps = 38 if cutoff is None else cutoff

    while True:
        batch_features = np.zeros((batch_size, time_steps, 1, 128, 9))
        batch_yields = np.zeros(batch_size)

        for i in range(batch_size):
            # Keep drawing until a usable sample is found; the bound prevents
            # an endless loop when nothing can be loaded at all.
            for _ in range(max_attempts):
                ID = random.choice(IDs)
                data = load_data(ID)
                if data is not None:
                    break
            else:
                raise RuntimeError(
                    'No usable sample found in %r after %d attempts'
                    % (data_dir, max_attempts)
                )

            batch_features[i, :, :, :, :] = data
            batch_yields[i] = yields[ID]

        yield (
            torch.tensor(batch_features, dtype=torch.float32),
            torch.tensor(batch_yields, dtype=torch.float32),
        )

# Registry of model constructors to train; the key is also used as the
# file name stem of the saved model.
model_functions = {
    'CNN_LSTM': CNN_LSTM,
    # 'SepCNN_LSTM': SepCNN_LSTM,
    # 'CONVLSTM': CONVLSTM,
    # 'CONV3D': CONV3D,
    # 'CONVLSTM_CONV3D': CONVLSTM_CONV3D
}


print("CNN", CNN_LSTM)

# Training configuration
BATCH_SIZE = 16
NUM_EPOCHS = 10

# Datasets
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/yields.p', 'rb') as yields_file:
    yields = pickle.load(yields_file)

# Retrieve the first 500 entries from yields['train']
# (not used further in this cell; kept for quick experiments on a subset)
first_500_train = {key: yields['train'][key] for key in list(yields['train'])[:500]}

print(len(yields['train']), len(yields['validation']))

# Generators
training_generator = generator(list(yields['train'].keys()), yields['train'], BATCH_SIZE)
validation_generator = generator(list(yields['validation'].keys()), yields['validation'], BATCH_SIZE)

# The generators never stop yielding, so every epoch must draw a fixed number
# of batches; iterating over them directly would never finish the first epoch.
train_steps = max(1, len(yields['train']) // BATCH_SIZE)
val_steps = max(1, len(yields['validation']) // BATCH_SIZE)

for model_name, model_function in model_functions.items():
    model = model_function(dimensions=[38, 1, 128, 9])
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()
    # earlystop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, verbose=1, mode='auto')
    # callbacks_list = [earlystop]

    for epoch in range(NUM_EPOCHS):
        # Training pass: exactly train_steps batches per epoch.
        model.train()
        train_losses = []
        for _ in tqdm.tqdm(range(train_steps), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
            batch_data, batch_labels = next(training_generator)
            optimizer.zero_grad()
            outputs = model(batch_data)
            # Targets get a trailing axis before the loss is computed.
            # NOTE(review): presumably the model outputs (batch, 1) - confirm.
            loss = criterion(outputs, batch_labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Validation pass: exactly val_steps batches, no gradient tracking.
        model.eval()
        val_losses = []
        with torch.no_grad():
            for _ in range(val_steps):
                val_data, val_labels = next(validation_generator)
                val_outputs = model(val_data)
                val_loss = criterion(val_outputs, val_labels.unsqueeze(1))
                val_losses.append(val_loss.item())
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {np.mean(train_losses):.4f}, Val Loss: {np.mean(val_losses):.4f}")

        # earlystop(np.mean(val_losses), model)  # Check early stopping criterion

        # if earlystop.early_stop:
        #     print("Early stopping")
        #     break

    # Save the model
    torch.save(model, f'{model_name}.pt')


CNN <class 'models.CNN_LSTM'>
6559 1511


Epoch 1/100: 10it [00:07,  1.37it/s]


KeyboardInterrupt: 