# SANS: Using a CNN to predict model ID
We use the Zenodo [repository](https://zenodo.org/records/10119316) as the data source for this project. There is no need to download the data locally: the `fsspec` library lets us work directly with the URLs of the `.h5` files, so the remote files are not loaded in full into the user's working memory.

In [1]:
%colors lightbg
%matplotlib inline

In [2]:
import h5py
import fsspec
import torch
from torch import nn
import pickle
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.optim as optim

## Load data

In [None]:
train_url = "https://zenodo.org/records/10119316/files/train.h5"
test_url = "https://zenodo.org/records/10119316/files/test.h5"
val_url = "https://zenodo.org/records/10119316/files/val.h5"
remote_f = fsspec.open(train_url, mode="rb")
if hasattr(remote_f, "open"):
    remote_f = remote_f.open()

In [None]:
class H5Dataset(Dataset):
    def __init__(self, h5_path, transforms=None):
        self.h5_file = h5py.File(h5_path, "r")
        self.transform = transforms

    def __getitem__(self, index):
        sample = self.h5_file["data"][index]
        if self.transform is not None:
            sample = self.transform(sample)
        return (
            sample,
            int(self.h5_file["target"][index]),
        )

    def __len__(self):
        return self.h5_file["target"].size

In [None]:
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((180, 180), antialias=True),
        torch.nn.ReLU(inplace=True),  # remove negative values if any
        torchvision.transforms.Lambda(lambda x: torch.log(x + 1.0)),
        torchvision.transforms.Lambda(
            lambda x: x / torch.max(x) if torch.max(x) > 0 else x
        ),
    ]
)
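# Note: `transforms` is defined above but not passed here, so samples are returned as raw arrays.
# Use H5Dataset(remote_f, transforms=transforms) to apply the preprocessing.
train_data = H5Dataset(remote_f)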
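
The last three steps of the pipeline above (clipping negative values, log scaling, and dividing by the maximum) can be illustrated on a tiny tensor. This is a minimal sketch using plain `torch` calls; the function name `normalise_intensity` and the values in `raw` are made up for illustration and are not part of the notebook.

```python
import torch


def normalise_intensity(x: torch.Tensor) -> torch.Tensor:
    """Clip negatives, apply log(x + 1), then scale so the maximum is 1."""
    x = torch.relu(x)        # remove negative values, if any
    x = torch.log(x + 1.0)   # compress the dynamic range
    peak = torch.max(x)
    return x / peak if peak > 0 else x  # avoid dividing by zero on empty images


# Hypothetical intensities, including one negative value
raw = torch.tensor([[-1.0, 0.0], [9.0, 99.0]])
scaled = normalise_intensity(raw)
print(scaled)
```

After this step every image lies in the range [0, 1], regardless of its original intensity scale.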


In [None]:
print(len(train_data))

### Plotting a typical 2D intensity array

In [None]:
idx = 537
psd_array = train_data[idx][0]
target = train_data[idx][1]
plt.figure()
plt.imshow(psd_array)
plt.title(target)
plt.show()

### Exploring the model names

Add the `.pkl` files to the `sas_helper/` directory manually: some of them are too large to upload to a GitHub repository, so they have been deliberately listed in `.gitignore`.

In [None]:
with open("sas_helper/model_names.pkl", "rb") as pf:
    model_names = pickle.load(pf)
print(model_names)

In our case, the model name is given as

In [None]:
model_names[39]

### PyTorch dataloader

Here is a code template for wrapping this dataset in a PyTorch `DataLoader`, which can then be used directly in a training loop.

In [None]:
train_dataloader = DataLoader(
    train_data,
    batch_size=32,  # tune this as well
    num_workers=2,  # adjust to the number of CPU cores available
    shuffle=True,
)

In [None]:
dataiter = iter(train_dataloader)
images, labels = next(dataiter)

### Loading instrument parameters

In [None]:
inst_params = {}
for partition in ["test", "train", "val"]:
    with open(f"sas_helper/inst_params_{partition}.pkl", "rb") as pf:
        inst_params[partition] = pickle.load(pf)

In [None]:
inst_params['train'].shape

In [None]:
inst_params['train'][0]

We loaded all the instrument parameters into the `inst_params` dictionary, keyed by the corresponding partition. The 10 one-hot encoded columns are:

In [None]:
inst_params_names = [
    'Lam_4.5', 'Lam_6.0',
    'zdepth_0.001', 'zdepth_0.002',
    'InstSetting_1', 'InstSetting_2', 'InstSetting_3',
    'SlitSetting_1', 'SlitSetting_2', 'SlitSetting_3',
]
print(inst_params_names)

- Lam = Lambda, the wavelength of the monochromatized neutrons in Angstroms. Two possible values: [4.5, 6.0]
- zdepth = sample thickness. Two possible values: [0.001, 0.002]
- InstSetting = instrument setting. Three possible values: [1, 2, 3]
- SlitSetting = slit (collimation) setting. Three possible values: [1, 2, 3]

Again, the variables are one-hot encoded, so the order of the columns matters. Read against the column names above, case 0 (printed earlier) should correspond to a wavelength of 4.5, a sample thickness of 0.001, and a slit setting of 2; this reading still needs to be checked against the printed row.
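
Decoding a one-hot row back into named settings can be sketched as follows. The column names are copied from the list above; the helper `decode_one_hot` and the row `example_row` are hypothetical and are not taken from the dataset.

```python
# Column names copied from the notebook; the order must match the data columns.
INST_PARAMS_NAMES = [
    "Lam_4.5", "Lam_6.0",
    "zdepth_0.001", "zdepth_0.002",
    "InstSetting_1", "InstSetting_2", "InstSetting_3",
    "SlitSetting_1", "SlitSetting_2", "SlitSetting_3",
]


def decode_one_hot(row, names=INST_PARAMS_NAMES):
    """Map a one-hot encoded row back to {parameter: value} using the column order."""
    decoded = {}
    for flag, name in zip(row, names):
        if flag:  # a 1 marks the active category of a parameter
            param, value = name.rsplit("_", 1)
            decoded[param] = value
    return decoded


# Hypothetical row for illustration only, not read from the .pkl files
example_row = [1, 0, 1, 0, 0, 1, 0, 0, 1, 0]
print(decode_one_hot(example_row))
```

Applied to `inst_params['train'][0]`, the same helper would show which settings case 0 was measured with, assuming the columns follow the order listed above.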

This matrix can be used as **input features** for the regression or classification task.


## Model

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # Works for NVIDIA and AMD GPUs
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # Metal Performance Shaders backend for Mac
else:
    device = torch.device("cpu")

In [None]:
class SansCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Expects 3-channel inputs; fc1's input size (64 * 2 * 2) matches 11x11 images
        self.conv1 = nn.Conv2d(3, 8, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 64, 2)
        self.fc1 = nn.Linear(64 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 46)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# The class has its own name so that re-running this cell does not shadow it with the instance
model = SansCNN().to(device)
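
The input size of `fc1` (`64 * 2 * 2`) ties the network to one particular image resolution. As a rough sketch, the arithmetic below follows the standard output-size formula for convolution and pooling layers without dilation; the helper names `conv_out` and `flat_features` are made up for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square Conv2d or MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def flat_features(side, channels=64):
    """Flattened size after conv1 -> pool -> conv2 -> pool for a square input."""
    side = conv_out(side, kernel=2)            # conv1: 2x2 kernel, stride 1
    side = conv_out(side, kernel=2, stride=2)  # pool: 2x2 window, stride 2
    side = conv_out(side, kernel=2)            # conv2: 2x2 kernel, stride 1
    side = conv_out(side, kernel=2, stride=2)  # pool: 2x2 window, stride 2
    return channels * side * side


print(flat_features(11))   # 256, i.e. 64 * 2 * 2
print(flat_features(180))  # 123904, i.e. 64 * 44 * 44
```

So the layer sizes above fit 11x11 inputs, while images resized to 180x180 would need `fc1` to accept 123904 input features.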

In [None]:
weights = list(model.parameters())
print(model)
print(len(weights))


In [None]:
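# output = model(train_data[0][0])
# Dummy batch of two 3-channel 11x11 inputs, created on the same device as the model
output = model(torch.randn(2, 3, 11, 11, device=device))
print(output)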

In [None]:
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.8)

In [None]:
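# Move the batch to the model's device before the forward pass.
# Note: the network as defined expects batches of shape (N, 3, 11, 11).
outputs = model(images.to(device))
loss = loss_func(outputs, labels.to(device))
loss.backward()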
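
The single forward and backward pass above computes gradients but does not update the weights. A full training step would also clear old gradients and call the optimizer. The sketch below shows that pattern on stand-in objects so that it runs on its own: `net` and `batches` are hypothetical replacements for `model` and `train_dataloader`, and the random data is not the SANS data.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Stand-in network and random batches (assumptions for illustration only)
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 11 * 11, 46))
batches = [(torch.randn(8, 3, 11, 11), torch.randint(0, 46, (8,))) for _ in range(4)]

loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.8)

losses = []
for epoch in range(2):
    for images, labels in batches:
        optimizer.zero_grad()                  # clear gradients from the previous step
        loss = loss_func(net(images), labels)  # forward pass and loss
        loss.backward()                        # compute gradients
        optimizer.step()                       # update the weights
        losses.append(loss.item())

print(f"final batch loss: {losses[-1]:.3f}")
```

In the notebook, the same loop body would iterate over `train_dataloader` and move each batch to `device` before the forward pass.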

In [None]:
PATH = './saved_model.pth'
torch.save(model.state_dict(), PATH)
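
Reloading the saved weights later can be sketched as below. A `state_dict` stores only the parameters, so the network has to be re-created with the same architecture first. The `nn.Linear` stand-in and the temporary path are assumptions made so that the example is self-contained.

```python
import os
import tempfile

import torch
from torch import nn

net = nn.Linear(4, 2)  # stand-in for the trained network
path = os.path.join(tempfile.mkdtemp(), "saved_model.pth")
torch.save(net.state_dict(), path)

restored = nn.Linear(4, 2)  # must have the same architecture as the saved network
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before making predictions
```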