In [1]:
import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [2]:
class MBDataSet(Dataset):
    """Dataset backed by a ``.npy`` array, min-max scaled per column to [-1, 1].

    The scaling is computed once in ``__init__`` so that ``__getitem__`` only
    has to index the pre-scaled array and convert the result to a tensor.

    Attributes (all NumPy arrays):
        data: the raw array returned by ``np.load``.
        max_data: maximum of ``data`` along axis 0.
        min_data: minimum of ``data`` along axis 0.
        standardized_data: ``data`` rescaled along axis 0 to [-1, 1].
    """

    def __init__(self, root):
        """Load the array stored at ``root`` and precompute its scaled copy.

        Args:
            root: anything ``np.load`` accepts (typically a path to a
                ``.npy`` file).
        """
        self.data = np.load(root)
        self.max_data = np.max(self.data, axis=0)
        self.min_data = np.min(self.data, axis=0)

        # A constant column has max == min; dividing by that zero range would
        # fill the column with NaN. Divide by 1 instead, which maps every
        # entry of such a column to -1 (the lower end of the target range).
        value_range = self.max_data - self.min_data
        safe_range = np.where(value_range == 0, np.ones_like(value_range), value_range)
        self.standardized_data = 2*(self.data - self.min_data)/safe_range - 1

    def __len__(self):
        """Return the number of samples, i.e. the size of the first axis."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the scaled sample(s) at ``index`` as a float32 tensor."""
        x = self.standardized_data[index]
        return torch.from_numpy(x).float()

In [3]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars.

    For an input of shape ``(B,)`` the output has shape
    ``(B, 2 * (dim // 2))``: the first half holds ``sin(x * f_k)`` and the
    second half ``cos(x * f_k)``, where the frequencies ``f_k`` decay
    geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        """Store the requested embedding width ``dim``."""
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (indexed as ``x[:, None]``, so it must be 1-D)."""
        half_dim = self.dim // 2
        # With dim 2 or 3, half_dim - 1 is 0 and the division would produce
        # inf (and then NaN frequencies); clamp the divisor to 1 instead.
        log_step = float(np.log(10000)) / max(half_dim - 1, 1)
        # Build the frequencies on x's device so the module also works when
        # the input does not live on the CPU.
        freqs = torch.exp(
            torch.arange(half_dim, device=x.device, dtype=torch.float32) * -log_step
        )
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [4]:
class MLP(nn.Module):
    """Plain fully connected network mapping 3 features to 3 features.

    The hidden widths double from 8 up to 128 and then halve back down to 8.
    Every linear layer except the last is followed by a ReLU.
    """

    def __init__(self):
        super().__init__()  # set up nn.Module bookkeeping before adding submodules
        widths = (3, 8, 16, 32, 64, 128, 64, 32, 16, 8, 3)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the final projection is left without an activation

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole layer stack and return the result."""
        return self.layers(x)

In [5]:
class Reparametrize(nn.Module):
    """Single linear projection from ``dim_in`` to ``dim_out`` followed by ReLU."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in, self.dim_out = dim_in, dim_out

        projection = nn.Linear(dim_in, dim_out)
        activation = nn.ReLU()
        self.layers = nn.Sequential(projection, activation)

    def forward(self, x):
        """Project ``x`` and apply the ReLU; the output is therefore >= 0."""
        return self.layers(x)

In [6]:
class CombinedBlock(nn.Module):
    """Fuse a data tensor with a time embedding.

    The data and the time embedding are each passed through their own
    Linear + ReLU layer, the two activations are added, and the sum goes
    through one more Linear + ReLU of width ``data_dim_out``.

    Args:
        data_dim_in: feature size of the incoming data tensor.
        data_dim_out: feature size of the block's output.
        time_dim_in: feature size of the incoming time embedding.
        time_dim_out: feature size of the projected time embedding. Because
            it is added to the projected data, it must broadcast against
            ``data_dim_out``.
    """

    def __init__(self, data_dim_in, data_dim_out, time_dim_in, time_dim_out):
        super().__init__()
        self.data_dim_in = data_dim_in
        self.data_dim_out = data_dim_out
        self.time_dim_in = time_dim_in
        self.time_dim_out = time_dim_out

        self.data_layer = nn.Sequential(
            nn.Linear(data_dim_in, data_dim_out),
            nn.ReLU()
        )

        self.time_layer = nn.Sequential(
            nn.Linear(time_dim_in, time_dim_out),
            nn.ReLU()
        )

        self.combined_data_layer = nn.Sequential(
            nn.Linear(data_dim_out, data_dim_out),
            nn.ReLU()
        )

    def forward(self, x, time_emb):
        """Return the fused representation of ``x`` and ``time_emb``."""
        h_data = self.data_layer(x)
        # Use the argument itself. The previous body read an undefined name
        # `t_emb`, which only resolved if a notebook global of that name
        # happened to exist.
        h_time = self.time_layer(time_emb)
        combined_output = h_data + h_time
        out = self.combined_data_layer(combined_output)
        return out

class MLPModule(nn.Module):
    """Stack of ``CombinedBlock``s that widens and then narrows the features.

    The widths walk up through ``dim_list`` and then back down through its
    reverse. Every block receives a time embedding of size ``dim_list[0]``
    and projects it to the block's own output width.
    """

    def __init__(self, dim_list = [4,8,16,32,64,128]):
        # The default list is only read here, never mutated.
        super().__init__()

        widths = list(dim_list)
        # Full width path: up through `widths`, then down again without
        # repeating the peak value.
        path = widths + widths[::-1][1:]

        self.block_list = nn.ModuleList(
            [
                CombinedBlock(width_in, width_out, widths[0], width_out)
                for width_in, width_out in zip(path, path[1:])
            ]
        )

    def forward(self, x, t_emb):
        """Pass ``x`` through every block, giving each the same ``t_emb``."""
        hidden = x
        for stage in self.block_list:
            hidden = stage(hidden, t_emb)
        return hidden

In [7]:
# if __name__ == '__main__':        What is this??
  
#   # Set fixed random number seed
#   torch.manual_seed(42)

In [8]:
# Load the saved array into the dataset, then serve it in shuffled
# mini-batches of 100 samples.
training_data = MBDataSet(root='data.npy')
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=100,
    shuffle=True,
)

In [9]:
# Model pieces: rp_i lifts the 3 data features to 4, mlp processes them
# together with a 4-dimensional time embedding from spm, and rp_f maps the
# result back to 3 features.
rp_i = Reparametrize(3,4)
rp_f = Reparametrize(4,3)
spm = SinusoidalPosEmb(4)
mlp = MLPModule()

loss_function = nn.MSELoss()

# rp_i and rp_f are part of the forward pass, so their weights have to be
# handed to the optimizer as well; otherwise they stay frozen at their random
# initial values. spm has no learnable parameters.
trainable_params = [*rp_i.parameters(), *mlp.parameters(), *rp_f.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

In [10]:
# Training loop: 5 passes over the data, one optimizer step per mini-batch.
for epoch in range(5):

    for i, data in enumerate(train_dataloader):
        inputs = rp_i(data)

        # One random time value per batch; its embedding is broadcast over
        # every sample in the batch.
        t = torch.rand(1)
        t_emb = spm(t)

        # Size the Gaussian targets from the actual batch instead of a
        # hard-coded 100, so a final, smaller batch still matches the shape
        # of `outputs` in the loss.
        targets = torch.normal(0, 1, size=(data.shape[0], 3))

        optimizer.zero_grad()
        # TODO (from the original author): inputs should be corrupted data,
        # not the original data.
        outputs_prime = mlp(inputs, t_emb)
        outputs = rp_f(outputs_prime)
        # NOTE(review): rp_f ends in a ReLU, so `outputs` is non-negative
        # while `targets` are N(0, 1) samples; confirm this is intended.
        loss = loss_function(outputs, targets)

        loss.backward()
        optimizer.step()
        