# A CNN-LSTM framework for the solar wind density forecasting
## ConvLSTM training
In this notebook we train a ConvLSTM network to predict solar wind densities (electrons + protons)


#### Notebook Contributors
* Andrea Giuseppe Di Francesco -- email: difrancesco.1836928@studenti.uniroma1.it
* Massimo Coppotelli -- email: coppotelli.1705325@studenti.uniroma1.it

In [None]:
# !pip install pandas
# !pip install numpy
# !pip install torch
# !pip install matplotlib
# !pip install torchvision
# !pip install wandb
# !pip3 install pytorch-lightning==1.5.10

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

# Personal files
from convlstm import *
from utils import *
from init import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if wb:
    wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdifra00[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
wind_dataset = pd.read_csv('./datasets/wind_dataset_1d_res.csv', index_col = 0)
sun_dataset = load_data('./datasets/ARI_image_dataset.json')

## Pytorch lightning code 
- Need of a collate function to preprocess data: sun_images are expressed as lists in a json files, thus we need a preprocessing before feed them into the ConvLSTM.
- Then we define a Lightning DataModule, and finally the Lightining Module for the model training.

In [4]:
class pl_Dataset_(pl.LightningDataModule):

    def __init__(self,  dataset, bs):
      

      self.train_set = dataset.loc[0:round(len(dataset)*train_split)]
      self.val_set = dataset.loc[round(len(dataset)*train_split)+1: round(len(dataset)*train_split) + round(len(dataset)*val_split)]

      self.bs = bs

    def setup(self, stage = None):
        if stage == 'fit':
            self.train_dataset = DataSet(self.train_set)
        elif stage == 'test':
            self.val_dataset = DataSet(self.val_set)
            

    def train_dataloader(self, *args, **kwargs):
        return DataLoader(self.train_dataset, batch_size = self.bs, shuffle = True, collate_fn = collate)

    def val_dataloader(self, *args, **kwargs):
        return DataLoader(self.val_dataset, batch_size = self.bs, shuffle = False, collate_fn = collate)


In [5]:
def collate(batch):
    ''' This is the collate_function for the DataLoader module of pytorch, indeed we just give the wind dataset, since it defines the dataset length, 
        and then we define the additional modules 
        INPUT: batch: batch_sizex(timestamp, proton_density, electron_density),
        OUTPUT: tensor: batch_size x time_steps x image_channels x image_height x image_width,  batch_size x proton_density x electron_density
    '''
    tensor = torch.zeros((len(batch), H+1, 1, 224, 224))
    density = torch.zeros((len(batch), 2)) # Proton and Electron Density tensor


    for sample in batch:
        d = 0

        requested_images = get_history_images(sample[0], H, D, resolution) # sample[0] corrensponds to the date of the solar wind prediction date.

        mid_tensor = torch.zeros((len(requested_images), 1, 224, 224))
        density_pair = torch.tensor([sample[1], sample[2]])

        for image_idx in range(len(requested_images)):

            image_tensor = torch.tensor(sun_dataset[requested_images[image_idx]])

            mid_tensor[image_idx] = image_tensor
        
        tensor[d] = mid_tensor
        density[d] = density_pair

        d += 1
        
    
    return tensor, density

In [6]:
SettingData = pl_Dataset_(wind_dataset, batch_size)

SettingData.setup('fit')


In [7]:
for stronzo in SettingData.train_dataloader():
    input = stronzo[0]
    break
input

tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0

In [9]:
model = HeliosNet(n_channels, n_hidden_channels, kernel_size, batch_first, bias)