# pytorch dataloaders

This notebook is used to develop the PyTorch dataset and its corresponding dataloaders.

PyTorch dataloaders wrap a well-defined PyTorch dataset to handle the process of generating training/testing/validation sets. A PyTorch dataset is simply a class that implements two methods, `__len__()` and `__getitem__()`. `__len__()` returns the size of the dataset, and `__getitem__()` returns a single sample and its corresponding label.

In [None]:
# %matplotlib widget
import glob
import os

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import numpy as np
import pandas as pd
from pyts.image import RecurrencePlot
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

import sys
sys.path.append('/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project')

from utils import generate_data_chunks

In [None]:
class ICMEDataset(Dataset):
    """Map-style dataset of saved image arrays described by a label table.

    The label table is a CSV file with a header row. The columns read by
    this class are ``fname_img`` (array file to load), ``fname`` (name of
    the matching time-series file), ``label``, and the date columns
    ``start_time`` / ``stop_time``.

    Parameters
    ----------
    icme_labels : str
        File name of the label table, looked up in ``<rootdir>/data``.
    rootdir : str
        Project root directory.
    datadir : str
        Directory holding the array files named in ``fname_img``.
    transform : callable, optional
        Applied to each loaded image before it is returned. When ``None``
        (the default) the image is returned exactly as loaded.
    """

    def __init__(self, icme_labels, rootdir, datadir, transform=None):
        self.rootdir = rootdir
        self.datadir = datadir
        # os.path.join copes with roots given with or without a trailing
        # separator, so no doubled '/' ends up in the path.
        self.df = pd.read_csv(
            os.path.join(rootdir, 'data', icme_labels),
            header=0,
            parse_dates=['start_time', 'stop_time']
        )
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(img, label, f_ts)`` for the row at position ``idx``.

        ``img`` is the array loaded from ``fname_img`` (after ``transform``,
        when one was supplied), ``label`` is the row's ``label`` value and
        ``f_ts`` is the row's ``fname`` value.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        f = self.df.fname_img.iloc[idx]
        f_ts = self.df.fname.iloc[idx]
        label = self.df.label.iloc[idx]
        img = np.load(os.path.join(self.datadir, f))

        # Previously the transform was stored but never used.
        if self.transform is not None:
            img = self.transform(img)

        return img, label, f_ts

In [None]:
# Training split: the label table is <rootdir>/data/sta_train_set.txt and the
# image arrays are read from datadir.
# NOTE(review): both paths are absolute and user-specific, so this cell only
# runs on a machine with the same directory layout.
icme_train_dataset = ICMEDataset(
    icme_labels='sta_train_set.txt',
    rootdir='/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/',
    datadir=(
        '/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/'
        'data/sta_chunks/'
    ),
)

In [None]:
# Test split: the label table is <rootdir>/data/sta_test_set.txt and the
# image arrays are read from datadir.
# NOTE(review): both paths are absolute and user-specific, so this cell only
# runs on a machine with the same directory layout.
icme_test_dataset = ICMEDataset(
    icme_labels='sta_test_set.txt',
    rootdir='/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/',
    datadir=(
        '/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/'
        'data/sta_chunks/'
    ),
)

In [None]:
# Validation split: the label table is <rootdir>/data/sta_validation_set.txt
# and the image arrays are read from datadir.
# NOTE(review): both paths are absolute and user-specific, so this cell only
# runs on a machine with the same directory layout.
icme_val_dataset = ICMEDataset(
    icme_labels='sta_validation_set.txt',
    rootdir='/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/',
    datadir=(
        '/Users/ndmiles/ClassWork/FallQuarter2021/aos_c204/aos_c204_final_project/'
        'data/sta_chunks/'
    ),
)

In [None]:
# Loader over the training split: one sample per batch, shuffled order,
# no worker subprocesses.
train_loader = DataLoader(
    dataset=icme_train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Loader over the test split: one sample per batch, shuffled order,
# no worker subprocesses.
test_loader = DataLoader(
    dataset=icme_test_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Loader over the validation split: one sample per batch, shuffled order,
# no worker subprocesses.
val_loader = DataLoader(
    dataset=icme_val_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Walk the whole training loader once, printing each batch's image shape
# and the time-series file name that came with it.
for batch in train_loader:
    img, label, fname_ts = batch
    print(img.shape, fname_ts)

In [None]:
# Load a single time-series chunk; the first column becomes the index and is
# parsed as dates.
t_df = pd.read_csv(
    '../data/sta_chunks/sta_ts_interval_2012-04-16_12_02_00_to_2012-04-19_12_02_00.txt',
    header=0,
    index_col=0,
    parse_dates=True,
)

In [None]:
# Preview the first rows of the loaded chunk.
t_df.head()

In [None]:
# (rows, columns) of the loaded chunk.
t_df.shape

In [None]:
# Plot the BTOTAL column against the frame's index.
t_df.plot(y='BTOTAL')

In [None]:
# Draw two batches from the training loader. Each dataset item is an
# (image, label, time-series file name) triple, so three names are needed;
# unpacking into two raised "ValueError: too many values to unpack".
img1, label1, fname_ts1 = next(iter(train_loader))
img2, label2, fname_ts2 = next(iter(train_loader))

In [None]:
def plot_img(img, label, cols=None):
    """Render the first sample of a batch with ``visualize_chunk_img``.

    Parameters
    ----------
    img : tensor
        Batched image tensor; only ``img[0]`` is used, and it is indexed
        once per entry of ``cols``.
    label : tensor
        Batched label; its first element is forwarded as ``icme``.
    cols : sequence of str, optional
        Key to use for each slice of ``img[0]``. Defaults to the slice
        positions formatted as strings ('0', '1', ...).

    Returns
    -------
    Whatever ``generate_data_chunks.visualize_chunk_img`` returns.
    """
    if cols is None:
        cols = [f'{i:0.0f}' for i in range(len(img[0]))]

    # Each slice is wrapped in a one-element list, keyed by its column name.
    img_dict = {
        name: [img[0][position].numpy()]
        for position, name in enumerate(cols)
    }

    return generate_data_chunks.visualize_chunk_img(
        img_dict, icme=label.numpy()[0]
    )

In [None]:
# Open the multi-page PDF that collects the figures saved further down.
# PdfPages takes no file-mode argument: its second positional parameter is
# `keep_empty`, so the former 'a' never requested append mode (and that
# parameter is deprecated in recent Matplotlib). An existing file with this
# name is overwritten.
pdf = PdfPages('icmes_images_trainining.pdf')

In [None]:
# Names passed to plot_img, one per slice of the image, in order.
# A missing comma after 'NP' used to fuse it with the next literal into
# 'NPTEMPERATURE', leaving seven names instead of eight.
cols = (
    'BTOTAL',
    'BX(RTN)',
    'BY(RTN)',
    'BZ(RTN)',
    'VP_RTN',
    'NP',
    'TEMPERATURE',
    'BETA',
)

In [None]:
# Write one PDF page for every training sample whose label equals 1.
try:
    # Each dataset item is an (image, label, time-series file name) triple;
    # unpacking only two names raised ValueError on the first batch.
    for img, label, fname_ts in train_loader:
        # The scalar comparison relies on the loader's batch_size of 1.
        if label == 1:
            fig = plot_img(img, label, cols=cols)
            # Pass the figure explicitly instead of relying on whichever
            # figure happens to be current.
            pdf.savefig(fig, bbox_inches='tight', dpi=150)
            # Release the figure so memory does not grow with every page.
            plt.close(fig)
finally:
    # Close the PDF even if plotting fails part-way through.
    pdf.close()