In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from ase.io import read
import matplotlib.pyplot as plt
from ase.visualize.plot import plot_atoms
import mogli
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms, utils
from torch.utils.data import  DataLoader
import torch.nn as nn   
from src.Dataloaders import ImageDataset
import pytorch_lightning as pl

In [None]:
# Preprocessing applied to every image: resize to 512x512, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ImageDataset("Datasets/3d", transform)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# Train / validation / test fractions. The last split absorbs the rounding
# remainder so the three lengths always sum to len(dataset).
proportions = [.75, .10, .15]
lengths = [int(p * len(dataset)) for p in proportions]
lengths[-1] = len(dataset) - sum(lengths[:-1])

# Seed the split so Restart & Run All reproduces the same partition. With an
# unseeded split the test subset changes on every run, so weights saved in one
# session could be evaluated on samples they were trained on in another.
SPLIT_SEED = 42
tr_dataset, vl_dataset, ts_dataset = torch.utils.data.random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(SPLIT_SEED))

In [None]:
# Multichannel 2D CNN
# The CNN consisted of three independent channels of six 3 × 3 ReLU-based
# convolutional layers with varying filter sizes (16, 32, 64, 128, 256, 256)
# and a max-pooling layer of size 2 × 2 per filter.
class module(nn.Module):
    """Convolution -> ReLU -> max-pool building block.

    Args:
        input_channels: number of channels the 2D convolution expects.
        output_channels: number of filters the 2D convolution produces.
        kernel_size: kernel size passed to the 2D convolution.

    NOTE(review): the pooling layer is ``nn.MaxPool3d(2)``, not a 2D pool.
    On a 4D activation the pooling window therefore also spans the channel
    axis, so the block returns ``output_channels // 2`` channels (with height
    and width halved as well). Whatever consumes this block must be sized for
    the halved channel count. Confirm whether a 2x2 spatial-only pool was
    intended before changing it, since that alters every downstream width.
    """

    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size)
        self.maxpool = nn.MaxPool3d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``maxpool(relu(conv(x)))``."""
        return self.maxpool(self.relu(self.conv(x)))

class CNN(nn.Module):
    """Single image branch: six ``module`` blocks applied in sequence.

    The blocks produce 16, 32, 64, 128, 256 and 256 filters with 3x3 kernels.
    From the second block on, the declared input width is half the previous
    block's filter count (8, 16, 32, 64, 128).

    NOTE(review): those halved input widths line up with ``module`` pooling
    with ``nn.MaxPool3d`` — confirm before changing either side.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = module(input_channels=3, output_channels=16, kernel_size=3)
        self.conv2 = module(input_channels=8, output_channels=32, kernel_size=3)
        self.conv3 = module(input_channels=16, output_channels=64, kernel_size=3)
        self.conv4 = module(input_channels=32, output_channels=128, kernel_size=3)
        self.conv5 = module(input_channels=64, output_channels=256, kernel_size=3)
        self.conv6 = module(input_channels=128, output_channels=256, kernel_size=3)

    def forward(self, x):
        """Run ``x`` through conv1 .. conv6 in order and return the result."""
        stages = (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.conv5, self.conv6)
        for stage in stages:
            x = stage(x)
        return x

# Smoke test: push one random 3-channel 512x512 input through a fresh CNN
# and print the shape of the resulting feature map.
x = torch.randn(1, 3, 512, 512)
model = CNN()
feature_map = model(x)
print(feature_map.shape)

In [None]:
class GapPrediction(nn.Module):
    """Six parallel ``CNN`` branches feeding one linear regression output.

    Each of the six inputs goes through its own ``CNN``. The six results are
    concatenated along dim 1, flattened per sample, passed through two ReLUs
    and a ``BatchNorm1d``, and projected to a single value per sample.
    """

    def __init__(self):
        super().__init__()
        # One independent branch per input; the attribute names are the
        # state_dict prefixes, so they are kept as cnn1 .. cnn6.
        self.cnn1 = CNN()
        self.cnn2 = CNN()
        self.cnn3 = CNN()
        self.cnn4 = CNN()
        self.cnn5 = CNN()
        self.cnn6 = CNN()
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        # Width of the flattened, concatenated branch outputs the head is
        # built for. NOTE(review): this is fixed, so it presumably matches
        # only the 512x512 inputs used in this notebook — confirm before
        # changing the image size.
        flat_width = 27648
        self.bn = nn.BatchNorm1d(flat_width)
        self.linear = nn.Linear(flat_width, 1)

    def forward(self, x, y, z, w, q, e):
        """Return one prediction per sample, shaped ``(batch, 1)``.

        Args:
            x, y, z, w, q, e: the six inputs, fed to cnn1 .. cnn6 respectively.
        """
        branches = (self.cnn1, self.cnn2, self.cnn3,
                    self.cnn4, self.cnn5, self.cnn6)
        views = (x, y, z, w, q, e)
        features = [branch(view) for branch, view in zip(branches, views)]

        merged = torch.cat(features, 1).flatten(start_dim=1)
        merged = self.relu2(self.relu1(merged))
        # NOTE(review): BatchNorm1d in training mode needs more than one
        # sample per batch.
        merged = self.bn(merged)
        return self.linear(merged)

# Smoke test: a batch of two random samples, with the same tensor reused for
# all six inputs, printed as the raw model output.
model = GapPrediction()
x = torch.randn(2, 3, 512, 512)
six_views = [x] * 6
print(model(*six_views))


In [None]:
class BANGAP(pl.LightningModule):
    """Lightning wrapper that trains ``GapPrediction`` with an MSE loss.

    Args:
        tr_dataset: dataset served by ``train_dataloader``.
        vl_dataset: dataset served by ``val_dataloader``.
        ts_dataset: dataset served by ``test_dataloader``.
        batch_size: batch size used by all three loaders.

    Every batch is unpacked as seven items: six model inputs followed by the
    label.
    """

    def __init__(self, tr_dataset, vl_dataset, ts_dataset, batch_size=32):
        super().__init__()
        self.model = GapPrediction()
        self.loss = nn.MSELoss()
        self.batch_size = batch_size
        # Loaders are created on first request and then reused.
        self.tr_dataset = tr_dataset
        self.tr_dataset_loader = None
        self.vl_dataset = vl_dataset
        self.vl_dataset_loader = None
        self.ts_dataset = ts_dataset
        self.ts_dataset_loader = None

    def forward(self, x, y, z, w, q, e):
        """Run the wrapped network on its six inputs.

        The wrapped ``GapPrediction`` requires six inputs; forwarding only
        three raised ``TypeError`` on every call.
        """
        return self.model(x, y, z, w, q, e)

    def _shared_step(self, batch, log_name):
        """Compute the MSE loss for one batch and log it under ``log_name``."""
        x, y, z, w, q, e, label = batch
        output = self.model(x, y, z, w, q, e)
        # The model emits (batch, 1); reshape the labels to the same layout.
        loss = self.loss(output, label.reshape(-1, 1))
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss; logged as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss; logged as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Return the test loss; logged as ``test_loss``."""
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

    def _make_loader(self, dataset, shuffle):
        """Build a single-process ``DataLoader`` with the configured batch size."""
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=0)

    def train_dataloader(self):
        """Return the cached training loader, creating it on first use."""
        if self.tr_dataset_loader is None:
            # Reshuffle every epoch; a fixed sample order was used before.
            self.tr_dataset_loader = self._make_loader(self.tr_dataset, shuffle=True)
        return self.tr_dataset_loader

    def val_dataloader(self):
        """Return the cached validation loader, creating it on first use."""
        if self.vl_dataset_loader is None:
            self.vl_dataset_loader = self._make_loader(self.vl_dataset, shuffle=False)
        return self.vl_dataset_loader

    def test_dataloader(self):
        """Return the cached test loader, creating it on first use."""
        if self.ts_dataset_loader is None:
            self.ts_dataset_loader = self._make_loader(self.ts_dataset, shuffle=False)
        return self.ts_dataset_loader

    
    

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = BANGAP(tr_dataset, vl_dataset, ts_dataset, batch_size=32).to(dev)
# model.load_state_dict(torch.load("model1.pt"))

# Request a GPU only when one exists: a hard-coded gpus=1 makes the Trainer
# fail on CPU-only machines even though `dev` already falls back to the CPU.
# NOTE(review): newer Lightning releases replace `gpus` with
# `accelerator` / `devices` — confirm against the installed version.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=10)
trainer.fit(model)

In [None]:
# Persist only the weights (the state dict), not the whole module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "model2.pt")

In [None]:
# Spot-check: print label vs. prediction for one random test batch.
dataloader = DataLoader(ts_dataset, batch_size=2, num_workers=0, shuffle=True)
model = model.to(dev)
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load("model2.pt", map_location=dev))
# Inference mode: BatchNorm uses its running statistics instead of the
# statistics of this 2-sample batch.
model.eval()

with torch.no_grad():  # no gradients are needed for a spot-check
    for batch in dataloader:
        x, y, z, w, q, e, label = batch
        # The network takes all six inputs; call the wrapped GapPrediction
        # directly with every one of them moved to the model's device.
        output = model.model(
            x.to(dev), y.to(dev), z.to(dev), w.to(dev), q.to(dev), e.to(dev))
        for real, pred in zip(label, output):
            print("real ", real.item(), "pred ", pred.item())
        break
