## Demo

This notebook demonstrates the training code, which is still under development.

In [None]:
from morphospectro.utils.datasets import GalaxyDataset
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from morphospectro.utils.networks import Feedforward
from skimage import io, transform
import matplotlib.pyplot as plt

In [None]:
# Build the dataset from the spectra file and the folder of images.
dataset = GalaxyDataset(
    spectra_file="/home/drd13/outputs/data/raw/data/s0_spectra.h5",
    image_folder="/home/drd13/outputs/data/raw/data/images",
)


In [None]:
# Turn the image of the first sample into a one-element batch: add a leading
# batch axis, move the last axis to position 1, and divide by 255.
# NOTE(review): presumably this is HWC uint8 -> (1, C, H, W) floats in [0, 1];
# confirm against GalaxyDataset.
test_im = torch.tensor(dataset[0][1])[None, ...].permute(0, 3, 1, 2)
test_im = torch.true_divide(test_im, 255)
print(test_im.shape)

In [None]:
# First element of the first sample, with a leading batch axis added.
test_spec = torch.unsqueeze(dataset[0][0], 0)
test_spec.shape

In [None]:
# Fixed-order mini-batches; the trailing partial batch is dropped, so every
# batch holds exactly n_batch samples.
n_batch = 64
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=n_batch,
    shuffle=False,
    drop_last=True,
)

In [None]:
# Network over spectra whose layer list narrows from 3280 to 50 and widens
# back to 3280, optimised with Adam at a shared learning rate.
lr = 1e-4
feedforward = Feedforward([3280, 1024, 512, 50, 512, 1024, 3280])
optimizer = torch.optim.Adam(params=feedforward.parameters(), lr=lr)

In [None]:
class FullNetwork(nn.Module):
    """Combine an image encoder and a spectrum encoder into one predictor.

    Each input is encoded by its own sub-network, the two latent batches are
    concatenated along dim 1, and the result is passed through
    ``merged_network``.

    Args:
        im_network: module applied to the image batch; must return a tensor
            with at least two dimensions.
        spec_network: module applied to the spectrum batch; must return a
            tensor that can be concatenated with the image latent on dim 1.
        merged_network: module applied to the concatenated latents.
    """

    def __init__(self, im_network, spec_network, merged_network):
        super().__init__()
        self.im_network = im_network
        self.spec_network = spec_network
        self.merged_network = merged_network

    def forward(self, spec, im):
        """Return ``merged_network``'s output for a (spectrum, image) batch.

        Note the argument order: spectra first, images second.
        """
        latent_im = self.im_network(im)
        latent_spec = self.spec_network(spec)
        merged_latent = torch.cat((latent_im, latent_spec), dim=1)
        # Use the sub-module registered on this instance, not a global that
        # merely happens to be called `merged_network`.
        return self.merged_network(merged_latent)

In [None]:
class ConvNet(nn.Module):
    """Two conv/pool stages followed by three fully connected layers.

    Args:
        in_channels: number of channels of the input images (default 3).
        flat_features: length of the flattened feature map entering ``fc1``.
            The default 61504 equals 16 * 62 * 62, which is the 16-channel
            map left after both conv/pool stages for e.g. a 256x256 input.
        out_features: width of the returned vector (default 50).
    """

    def __init__(self, in_channels=3, flat_features=61504, out_features=50):
        super().__init__()
        # Unpadded 3x3 convolutions: in_channels -> 6 -> 16 channels.
        self.conv1 = nn.Conv2d(in_channels, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # Affine layers (y = Wx + b) on the flattened feature map.
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_features)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to an (N, out_features) batch."""
        # Each stage: convolution, ReLU, then max pooling over a 2x2 window.
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer.
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        num_features = 1
        for dim in x.size()[1:]:
            num_features *= dim
        return num_features

In [None]:
# Two encoders with 50 outputs each, plus a head whose first layer takes the
# 100 concatenated features and whose last layer has a single output.
spec_network = Feedforward([3280, 1024, 512, 50])
im_network = ConvNet()
merged_network = Feedforward([100, 64, 32, 1])
full_network = FullNetwork(
    im_network=im_network,
    spec_network=spec_network,
    merged_network=merged_network,
)

In [None]:
# Adam over every parameter of the combined network, at the shared learning rate.
optimizer_full = torch.optim.Adam(params=full_network.parameters(), lr=lr)

In [None]:
# Smoke test on a single sample. FullNetwork.forward takes (spec, im):
# spectrum first, image second.
full_network(test_spec, test_im)

In [None]:
# Mean-squared error is the active criterion; the cross-entropy alternative
# below is left disabled.
# loss = torch.nn.CrossEntropyLoss()
loss = nn.MSELoss()

In [None]:
# Fetch one batch explicitly so this cell does not depend on `im` being left
# over from a training loop further down the notebook.
spec, im, idx = next(iter(loader))
im.shape

In [None]:
# Forward-only pass of the combined network over the loader; no loss or
# optimizer step is wired in here yet, so only the predictions are printed
# (the old `err` print referred to a value this cell never computed).
for i in range(100):
    for (spec, im, idx) in loader:
        # Apply the same preprocessing as the single-image `test_im` cell.
        # NOTE(review): assumes batches arrive as (B, H, W, C) raw pixel
        # values, like dataset[0][1] - confirm against GalaxyDataset.
        im = torch.true_divide(im.permute(0, 3, 1, 2), 255)
        prob_match = full_network(spec, im)
        print(prob_match)

In [None]:
# Train the combined network to tell true (spectrum, image) pairs (target 1)
# from pairs whose spectra were moved to a different image (target 0).
# The fixed-size targets rely on the loader dropping the last partial batch.
label_real = torch.ones(n_batch, 1)
label_scrambeled = torch.zeros(n_batch, 1)
for i in range(100):
    for (spec, im, idx) in loader:
        # Apply the same preprocessing as the single-image `test_im` cell.
        # NOTE(review): assumes batches arrive as (B, H, W, C) raw pixel
        # values, like dataset[0][1] - confirm against GalaxyDataset.
        im = torch.true_divide(im.permute(0, 3, 1, 2), 255)
        optimizer_full.zero_grad()
        # Roll the spectra along the batch axis by a random non-zero offset,
        # so no spectrum stays aligned with its own image.
        # NOTE(review): replaces the undefined `randomize` helper - confirm
        # that within-batch mismatching is what was intended.
        shift = int(torch.randint(1, n_batch, (1,)))
        scrambled_spec = torch.roll(spec, shifts=shift, dims=0)
        prob_real = full_network(spec, im)
        prob_scrambeled = full_network(scrambled_spec, im)
        err_real = loss(prob_real, label_real)
        err_fake = loss(prob_scrambeled, label_scrambeled)
        err_tot = err_real + err_fake
        err_tot.backward()
        # Without this step the gradients were computed but never applied.
        optimizer_full.step()
        print(prob_real)
        print(f"err:{err_tot}")

In [None]:
# Size of the most recent prediction batch produced by the combined network.
prob_match.size()

In [None]:
# Train `feedforward` to reproduce its input: each spectrum batch is passed
# through the network and compared against itself with `loss`.
for epoch in range(100):
    for spec, im, idx in loader:
        optimizer.zero_grad()
        spec_pred = feedforward(spec)
        err = loss(input=spec_pred, target=spec)
        err.backward()
        optimizer.step()
        print(f"err:{err}")

In [None]:
# Recompute the loss for the last batch using the existing prediction.
err = loss(input=spec_pred, target=spec)

In [None]:
# Display the network object as the cell output.
feedforward

In [None]:
# Backpropagate the recomputed loss. No zero_grad() is called in this cell,
# so the result is added to whatever is already stored in the .grad fields.
err.backward()

In [None]:
# Inspect the weight gradient of the first entry of `feedforward.fc`.
# NOTE(review): assumes Feedforward keeps its layers in an indexable `fc` - confirm.
feedforward.fc[0].weight.grad