In [1]:
import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.cuda.is_available()

True

In [2]:
class MBDataSet(Dataset):

    def __init__(self, root): #I think I got this, but ask about other arguments
        self.data = np.load(root)
        self.max_data = np.max(self.data, axis=0)
        self.min_data = np.min(self.data, axis=0)
        self.standardized_data = 2*(self.data - self.min_data)/(self.max_data - self.min_data) - 1

    def __len__(self):
        return len(self.data) #need to verify the index here

    def __getitem__(self, index):
        x = self.standardized_data[index]
        return torch.from_numpy(x).float() #Ask here

In [3]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

In [4]:
class Reparametrize(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.layers = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layers(x)

In [5]:
class Reparametrize2(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.layers = nn.Sequential(
            nn.Linear(dim_in, dim_out)
        )

    def forward(self, x):
        return self.layers(x)

In [6]:
class CombinedBlock(nn.Module):

    def __init__(self, data_dim_in, data_dim_out, time_dim_in, time_dim_out):
        super().__init__()
        self.data_dim_in = data_dim_in
        self.data_dim_out = data_dim_out
        self.time_dim_in = time_dim_in
        self.time_dim_out = time_dim_out


        self.data_layer = nn.Sequential(
            nn.Linear(data_dim_in, data_dim_out),
            nn.Mish()
        )

        self.time_layer = nn.Sequential(
            nn.Linear(time_dim_in, time_dim_out),
            nn.Mish()
        )

        self.combined_data_layer = nn.Sequential(
            nn.Linear(data_dim_out, data_dim_out),
            nn.Mish()
        )

    def forward(self, x, time_emb):
        h_data = self.data_layer(x)
        h_time = self.time_layer(t_emb)
        combined_output = h_data + h_time
        out = self.combined_data_layer(combined_output)
        return out

class MLPModule(nn.Module):

    def __init__(self, dim_list = [4,8,16,32]):
        super().__init__()
        self.block_list = nn.ModuleList()
#         self.block_list.append(Reparametrize(3,4))
        upsample, downsample = dim_list, dim_list[::-1]

        for data_dim_in, data_dim_out in zip(upsample[:-1], upsample[1:]):
            self.block_list.append(CombinedBlock(data_dim_in, data_dim_out, dim_list[0], data_dim_out))
        for data_dim_in, data_dim_out in zip(downsample[:-1], downsample[1:]):
            self.block_list.append(CombinedBlock(data_dim_in, data_dim_out, dim_list[0], data_dim_out))
#         self.block_list.append(Reparametrize(4,3))


    def forward(self, x, t_emb):

        for block in self.block_list:
            x = block(x, t_emb)
        return x

In [None]:
training_data=MBDataSet('/content/data.npy')
batch_size=100
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
training_data.__getitem__(0)

tensor([ 0.7915,  0.4727, -1.0000])

In [None]:
def loss_function(outputs, targets):
    return ((outputs - targets)**2).mean()

In [None]:
rp_i = Reparametrize(3,4).to('cuda')
rp_f = Reparametrize2(4,3).to('cuda')
spm = SinusoidalPosEmb(4)
mlp = MLPModule().to('cuda')
rp_i.train()
mlp.train()
rp_f.train()
params = list(mlp.parameters()) + list(rp_i.parameters()) + list(rp_f.parameters())
# loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(params, lr=1e-3)     #preferences?

In [None]:
beta_start=0.0001
beta_end=0.02
num_diffusion_steps=1000

times = torch.round(torch.linspace(0,1,num_diffusion_steps), decimals=3)
betas = torch.linspace(beta_start,beta_end,num_diffusion_steps) # from implementation in Ho, et. al.
# we use betas (which are just rescaled times) to parameterize the diffusion process

loss_arr = []

for epoch in range(15):             #Preferences?

    print(f'Starting epoch {epoch+1}')
    current_loss = 0.0

    for i, data in enumerate(train_dataloader):
        time_indices = torch.randint(num_diffusion_steps, (batch_size,1))
        sampled_times = times[time_indices]
        # sampling a different time for each data sample in batch
        t_emb = spm(sampled_times).to('cuda')

        alphas=(1-betas).cumprod(dim=0)[time_indices].view(batch_size,1)
        # selecting alpha/beta values corresponding to times
        noise = torch.normal(0, 1, size=data.shape)
        corrupted_data = alphas**0.5*data + (1-alphas)**0.5*noise
        noise = noise.to('cuda')
        corrupted_data = corrupted_data.to('cuda')
        # corrupting data according to timestep using implementation in DDPM, Ho, et. al.

        optimizer.zero_grad()

        inputs = rp_i(corrupted_data)
        outputs_prime = mlp(inputs, t_emb)
        outputs = rp_f(outputs_prime)

        loss = loss_function(outputs, noise)

        loss.backward()

        optimizer.step()

        # loss_arr.append(loss.detach())
        # print(loss_arr())
        # if i % 5 == 0:
        #     print(loss)

        current_loss += loss.item()
        if i % 1000 == 999:
            print('Loss after mini-batch %5d: %.3f' %
                (i + 1, current_loss / 1000))
            current_loss = 0.0

print('Training is complete')


Starting epoch 1
Loss after mini-batch  1000: 0.776
Loss after mini-batch  2000: 0.470
Loss after mini-batch  3000: 0.433
Loss after mini-batch  4000: 0.429
Loss after mini-batch  5000: 0.427
Loss after mini-batch  6000: 0.428
Loss after mini-batch  7000: 0.424
Loss after mini-batch  8000: 0.425
Loss after mini-batch  9000: 0.419
Loss after mini-batch 10000: 0.419
Starting epoch 2


In [None]:
rp_i.eval()
mlp.eval()
rp_f.eval()

eta=1

sampling_batch_size = 10000

x = torch.randn(
    sampling_batch_size,
    3,
    dtype=torch.float
)

x = x.float()
x = x.to('cuda')

# betas = torch.from_numpy(
#     get_beta_schedule(beta_start, beta_end, num_diffusion_timesteps)).float().to('cuda')

seq = range(1, num_diffusion_steps)
noise_arr = []
std_arr = []
with torch.no_grad():
    n = x.size(0)
    seq_next = [0] + list(seq[:-1])
    x0_preds, xs, x_last = [], [x], []
    xt_next = x
    times = list(zip(reversed(seq), reversed(seq_next)))

    for i,j in times:

        t = (torch.ones(n) * i)
        next_t = (torch.ones(n) * j)
        at = (1-betas).cumprod(dim=0)[t.long()].view(sampling_batch_size,1).to('cuda')
        at_next = (1-betas).cumprod(dim=0)[next_t.long()].view(sampling_batch_size,1).to('cuda')

        xt = xt_next
        xt[::,2] = 0

        t_emb = spm(t).to('cuda')
        xt_prime = rp_i(xt)
        et_prime = mlp(xt_prime, t_emb)
        et = rp_f(et_prime)

        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
        x0_t[::,2] = 0

        c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
        c2 = ((1 - at_next) - c1**2).sqrt()
        xt_next = at_next.sqrt() * x0_t + c1 * torch.normal(0, 1, size=x.size()).to('cuda') + c2 * et
        xt_next[::,2] = 0

In [None]:
xt_next

tensor([[-0.2393,  1.1753,  0.0000],
        [-0.2216,  1.1259,  0.0000],
        [ 0.1980,  0.6617,  0.0000],
        ...,
        [ 0.0513,  0.3594,  0.0000],
        [-0.7214,  0.8230,  0.0000],
        [ 0.0012,  0.5561,  0.0000]], device='cuda:0')