In [2]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.optim import Adam
import torch.nn.functional as F
from new_data import HSIDataLoader,TestDS, TrainDS
import numpy as np
from plot import show_tensor_image,show_spectral_curve
from torch import nn
import math
%matplotlib inline

In [3]:
# Load the hyperspectral samples and wrap all of them in a single DataLoader.
batch_size = 2048
dataloader = HSIDataLoader({})
train_loader, X, Y = dataloader.generate_torch_dataset()

# Drop the singleton axes of X, insert a channel axis at position 1
# (giving (N, 1, bands)), then apply the affine map x -> 2x - 1.
# NOTE(review): this presumes X is already scaled to [0, 1], so the result
# lies in [-1, 1] — confirm against HSIDataLoader.
newX = np.expand_dims(X.squeeze(), 1) * 2 - 1
newY = Y

# Unshuffled loader over every sample, loaded in the main process.
all_data_loader = torch.utils.data.DataLoader(
    dataset=TrainDS(newX, newY),
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)
print(newX.shape, newY.shape)

[data] load data shape data=(145, 145, 200), label=(145, 145)
[data] data patches shape data=(21025, 1, 1, 200), label=(21025,)
------[data] after transpose train, test------
X.shape= (21025, 200, 1, 1)
Y.shape= (21025,)
(21025, 1, 200) (21025,)


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

class Diffusion(object):
    """Forward (noising) and reverse (sampling) diffusion with a linear beta schedule.

    The schedule tensors built in ``__init__`` are 1-D with ``T`` entries.
    They are indexed with a CPU copy of the timestep tensor and the looked-up
    values are then moved to the timestep tensor's device.
    """
    def __init__(self) -> None:
        self.T = 100
        self.betas = self._linear_beta_schedule(timesteps=self.T)
        # Pre-calculate different terms for closed form
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Shift right by one and pad with 1.0 so that index 0 means "no noise yet".
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        # Because alphas_cumprod_prev[0] == 1, posterior_variance[0] == 0.
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)


    def _linear_beta_schedule(self, timesteps, start=0.0001, end=0.02):
        """Return ``timesteps`` betas evenly spaced from ``start`` to ``end``."""
        return torch.linspace(start, end, timesteps)

    def _get_index_from_list(self, vals, t, x_shape):
        """
        Look up ``vals[t]`` per batch element and shape it for broadcasting.

        Args:
            vals: 1-D schedule tensor.
            t: 1-D integer tensor of timesteps, one per batch element.
            x_shape: Shape of the tensor the result will be multiplied with.

        Returns:
            Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
            dimensions, placed on ``t.device``.
        """
        batch_size = t.shape[0]
        out = vals.gather(-1, t.cpu())
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

    def forward_diffusion_sample(self, x_0, t, device="cpu"):
        """
        Noise a clean sample ``x_0`` to timestep ``t`` in closed form.

        Returns:
            ``(x_t, noise)`` on ``device``, where
            ``x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise``.
        """
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = self._get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )
        # mean + variance
        return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


    def get_loss(self, model, x_0, t):
        """
        Noise ``x_0`` to step ``t`` and score the model's noise prediction.

        Tensors are moved to the module-level ``device`` before the model runs.

        Returns:
            ``(l1_loss, x_noisy, noise, noise_pred)``.
        """
        x_noisy, noise = self.forward_diffusion_sample(x_0, t, device)
        noise_pred = model(x_noisy, t)
        # (input, target) order as documented; L1 is symmetric so the value is unchanged.
        return F.l1_loss(noise_pred, noise), x_noisy, noise, noise_pred


    @torch.no_grad()
    def sample_timestep(self, x, t, model):
        """
        Run one reverse step: predict the noise in ``x`` and return ``x_{t-1}``.

        Args:
            x: Current noisy sample ``x_t``.
            t: 1-D integer tensor of timesteps; either a single entry that is
                broadcast over the batch, or one entry per batch element.
            model: Callable ``model(x, t)`` returning the predicted noise.

        Returns:
            The model mean, plus posterior noise for every element whose
            timestep is not 0.
        """
        betas_t = self._get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = self._get_index_from_list(self.sqrt_recip_alphas, t, x.shape)

        # Call model (current image - noise prediction)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        posterior_variance_t = self._get_index_from_list(self.posterior_variance, t, x.shape)

        needs_noise = t != 0
        if not bool(needs_noise.any()):
            # Final step for every element: return the mean without drawing noise.
            return model_mean

        noise = torch.randn_like(x)
        # Per-element mask so a batch may mix t == 0 with t > 0 entries.
        mask = needs_noise.reshape(-1, *((1,) * (x.dim() - 1))).to(x.dtype)
        return model_mean + mask * torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def sample_plot_image(self, xt=None, num=5, model=None, shape=None):
        """
        Run the full reverse process and collect intermediate snapshots.

        Starts either from the given ``xt`` or, when ``xt`` is None, from pure
        Gaussian noise of the given ``shape``, and denoises step by step from
        ``t = T - 1`` down to ``t = 0``.

        Args:
            xt: Starting sample; moved to the module-level ``device``.
            num: Approximate number of snapshots to keep (must be >= 1).
            model: Noise-prediction callable ``model(x, t)``; required.
            shape: Shape of the noise to start from when ``xt`` is None.

        Returns:
            List of CPU tensors, one for every step ``i`` with
            ``i % stepsize == 0`` where ``stepsize = max(1, int(T / num))``,
            ordered from the noisiest to the final sample.

        Raises:
            ValueError: If ``model`` is missing, ``num < 1``, or neither
                ``xt`` nor ``shape`` is given.
        """
        if model is None:
            raise ValueError("sample_plot_image requires `model`, the noise-prediction network")
        if num < 1:
            raise ValueError("num must be >= 1")
        # Clamp to 1 so that num > T cannot produce a zero modulus below.
        stepsize = max(1, int(self.T / num))
        res = []
        # Sample noise
        if xt is None:
            if shape is None:
                raise ValueError("pass either `xt` or `shape` to define the sample to denoise")
            img = torch.randn(tuple(shape), device=device)
        else:
            img = xt.to(device)
        for i in reversed(range(self.T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = self.sample_timestep(img, t, model)
            if i % stepsize == 0:
                res.append(img.detach().cpu())

        return res

In [5]:

class Block1D(nn.Module):
    """One U-Net stage for 1-D signals: conv, add time embedding, conv, resample.

    Args:
        in_ch: Nominal input channel count of the stage.
        out_ch: Channel count produced by the stage.
        time_emb_dim: Width of the time-embedding vector passed to ``forward``.
        up: If True, the first convolution expects ``2 * in_ch`` channels and
            the stage ends with a stride-2 transposed convolution; otherwise
            it ends with a stride-2 convolution.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        if up:
            first_in, resample_cls = 2 * in_ch, nn.ConvTranspose1d
        else:
            first_in, resample_cls = in_ch, nn.Conv1d

        # Creation order (time_mlp, conv1, transform, conv2, norms) is kept
        # fixed so seeded initialisation and state_dict key order are stable.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv1d(first_in, out_ch, kernel_size=3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm1d(out_ch)
        self.bnorm2 = nn.BatchNorm1d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the stage to ``x`` using the time embedding ``t``."""
        # First convolution, activation, then normalisation.
        features = self.bnorm1(self.relu(self.conv1(x)))
        # Project the time embedding to out_ch and append one trailing axis
        # so it broadcasts over the signal length.
        time_bias = self.relu(self.time_mlp(t)).unsqueeze(-1)
        # Inject the time information, then run the second convolution.
        features = self.bnorm2(self.relu(self.conv2(features + time_bias)))
        # Halve (down) or double (up) the signal length.
        return self.transform(features)



class SinusoidalPositionEmbeddings(nn.Module):
    """Map timesteps to sinusoidal feature vectors.

    For a 1-D input of length ``batch`` the output has shape
    ``(batch, 2 * (dim // 2))``: the first half holds sines and the second
    half cosines of the timestep multiplied by ``dim // 2`` geometrically
    spaced frequencies running from 1 down to 1/10000.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        n_freqs = self.dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        # TODO: Double check the ordering here
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet1D(nn.Module):
    """
    Small 1-D U-Net that predicts noise from a noisy signal and its timestep.

    Three down-sampling stages (16 -> 32 -> 64 -> 128 channels) are mirrored by
    three up-sampling stages; each up stage receives the output of the matching
    down stage concatenated along the channel axis.
    """
    def __init__(self, _image_channels=1):
        super().__init__()
        image_channels = _image_channels
        down_channels = (16, 32, 64, 128)
        up_channels = (128, 64, 32, 16)
        out_kernel = 1  # kernel size of the final projection
        time_emb_dim = 32

        # Timestep -> sinusoidal features -> linear -> ReLU.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Lift the input to the first feature width.
        self.conv0 = nn.Conv1d(image_channels, down_channels[0], 3, padding=1)

        # Encoder stages, one per consecutive channel pair.
        self.downs = nn.ModuleList([
            Block1D(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        ])
        # Decoder stages, one per consecutive channel pair.
        self.ups = nn.ModuleList([
            Block1D(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        ])

        # Project back to the input channel count.
        self.output = nn.Conv1d(up_channels[-1], image_channels, out_kernel)

    def forward(self, x, timestep):
        """Return the prediction for ``x`` (batch, channel, length) at ``timestep``."""
        t = self.time_mlp(timestep)
        x = self.conv0(x)

        # Encoder: remember every stage output for the skip connections.
        skips = []
        for stage in self.downs:
            x = stage(x, t)
            skips.append(x)

        # Decoder: consume the skips last-in-first-out as extra channels.
        for stage in self.ups:
            x = stage(torch.cat((x, skips.pop()), dim=1), t)

        return self.output(x)


# Build the diffusion schedule helper and a noise-prediction network.
diffusion = Diffusion()
model = SimpleUnet1D()

In [7]:
# Train a fresh network to predict the noise added at a random timestep.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch, channel, spe = newX.shape  # note: `batch` is rebound by the loop below
model = SimpleUnet1D().to(device)
# NOTE(review): lr=0.1 is far above Adam's default of 1e-3 — confirm it is intentional.
optimizer = Adam(model.parameters(), lr=0.1)
epochs = 1000

for epoch in range(epochs):
    # Sample-weighted running totals for the epoch average.
    epoch_loss, epoch_size = 0, 0
    for step, (batch, _) in enumerate(all_data_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        # One uniformly drawn timestep per sample in the batch.
        t = torch.randint(0, diffusion.T, (batch.shape[0],), device=device).long()
        loss, temp_xt, temp_noise, temp_noise_pred = diffusion.get_loss(model, batch, t)
        loss.backward()
        optimizer.step()

        epoch_size += batch.shape[0]
        epoch_loss += (loss.item() * batch.shape[0])

    # batch_loss is the loss of the last batch in the epoch.
    print("epoch: %s, epoch_loss=%s, batch_loss=%s" %(epoch, epoch_loss/epoch_size, loss.item()))

epoch: 0, epoch_loss=0.7359533051219763, batch_loss=0.564034104347229
epoch: 1, epoch_loss=0.4432899054690008, batch_loss=0.39997878670692444
epoch: 2, epoch_loss=0.3501453744160293, batch_loss=0.33556509017944336
epoch: 3, epoch_loss=0.3296981728374604, batch_loss=0.3205973505973816
epoch: 4, epoch_loss=0.31696762806170053, batch_loss=0.32107067108154297
epoch: 5, epoch_loss=0.3044800476852127, batch_loss=0.30964457988739014
epoch: 6, epoch_loss=0.2987603702017867, batch_loss=0.3031260371208191
epoch: 7, epoch_loss=0.29068611843974357, batch_loss=0.3044062554836273
epoch: 8, epoch_loss=0.2818117147365166, batch_loss=0.3010668158531189
epoch: 9, epoch_loss=0.27191220063800337, batch_loss=0.27067452669143677
epoch: 10, epoch_loss=0.2655836780629742, batch_loss=0.269275426864624
epoch: 11, epoch_loss=0.26502322037211495, batch_loss=0.2610040009021759
epoch: 12, epoch_loss=0.25847614475152725, batch_loss=0.2592470049858093
epoch: 13, epoch_loss=0.25149683650552024, batch_loss=0.2561828792

epoch: 112, epoch_loss=0.18485075427434108, batch_loss=0.19096890091896057
epoch: 113, epoch_loss=0.1904562009607002, batch_loss=0.18054626882076263
epoch: 114, epoch_loss=0.1869781337091954, batch_loss=0.18500258028507233
epoch: 115, epoch_loss=0.18692127917524468, batch_loss=0.18551164865493774
epoch: 116, epoch_loss=0.18900034870460683, batch_loss=0.18624234199523926
epoch: 117, epoch_loss=0.18709418645015652, batch_loss=0.19142533838748932
epoch: 118, epoch_loss=0.18851656470159855, batch_loss=0.18392033874988556
epoch: 119, epoch_loss=0.1881585103138732, batch_loss=0.1773546040058136
epoch: 120, epoch_loss=0.1888755955934241, batch_loss=0.18048995733261108
epoch: 121, epoch_loss=0.18470919768818778, batch_loss=0.18197914958000183
epoch: 122, epoch_loss=0.18706170998628868, batch_loss=0.17157012224197388
epoch: 123, epoch_loss=0.18598840943131237, batch_loss=0.1809931993484497
epoch: 124, epoch_loss=0.18652278182055806, batch_loss=0.18483227491378784
epoch: 125, epoch_loss=0.187608

epoch: 223, epoch_loss=0.18392411589409877, batch_loss=0.16913942992687225
epoch: 224, epoch_loss=0.17938327929423623, batch_loss=0.18311633169651031
epoch: 225, epoch_loss=0.1811715045135455, batch_loss=0.17763866484165192
epoch: 226, epoch_loss=0.17999352505105004, batch_loss=0.1816737800836563
epoch: 227, epoch_loss=0.18333246580627388, batch_loss=0.16880565881729126
epoch: 228, epoch_loss=0.18369420453356788, batch_loss=0.17383702099323273
epoch: 229, epoch_loss=0.18049189151183886, batch_loss=0.16954274475574493
epoch: 230, epoch_loss=0.1800179308475694, batch_loss=0.16778185963630676
epoch: 231, epoch_loss=0.1789738156856453, batch_loss=0.17939336597919464
epoch: 232, epoch_loss=0.1819982147918445, batch_loss=0.19267849624156952
epoch: 233, epoch_loss=0.18421618525111577, batch_loss=0.17383646965026855
epoch: 234, epoch_loss=0.17960262229057045, batch_loss=0.1666814535856247
epoch: 235, epoch_loss=0.17632459425614366, batch_loss=0.16910946369171143
epoch: 236, epoch_loss=0.173782

epoch: 334, epoch_loss=0.17152352871353363, batch_loss=0.16658537089824677
epoch: 335, epoch_loss=0.1744082161015194, batch_loss=0.1692628711462021
epoch: 336, epoch_loss=0.1728023162594023, batch_loss=0.16620931029319763
epoch: 337, epoch_loss=0.1719625813470584, batch_loss=0.17115360498428345
epoch: 338, epoch_loss=0.1737906539822872, batch_loss=0.1643107533454895
epoch: 339, epoch_loss=0.17288045240426603, batch_loss=0.171861931681633
epoch: 340, epoch_loss=0.17385188559927356, batch_loss=0.17681889235973358
epoch: 341, epoch_loss=0.17655675980834984, batch_loss=0.17113061249256134
epoch: 342, epoch_loss=0.1747833845958415, batch_loss=0.16540104150772095
epoch: 343, epoch_loss=0.1736808116895265, batch_loss=0.16776660084724426
epoch: 344, epoch_loss=0.17325284831163856, batch_loss=0.168950617313385
epoch: 345, epoch_loss=0.17125540688545327, batch_loss=0.1593989133834839
epoch: 346, epoch_loss=0.1693572160392539, batch_loss=0.1643679440021515
epoch: 347, epoch_loss=0.170143316021005

epoch: 445, epoch_loss=0.17027027958071614, batch_loss=0.16492703557014465
epoch: 446, epoch_loss=0.16969337013331945, batch_loss=0.15517976880073547
epoch: 447, epoch_loss=0.16830842499146137, batch_loss=0.1625850647687912
epoch: 448, epoch_loss=0.16806474695732987, batch_loss=0.15983936190605164
epoch: 449, epoch_loss=0.17019949227961292, batch_loss=0.1674935519695282
epoch: 450, epoch_loss=0.1703394568462576, batch_loss=0.17214488983154297
epoch: 451, epoch_loss=0.1702620488894255, batch_loss=0.15805640816688538
epoch: 452, epoch_loss=0.17068066493863868, batch_loss=0.15792657434940338
epoch: 453, epoch_loss=0.17186297608896045, batch_loss=0.16146957874298096
epoch: 454, epoch_loss=0.17061878225253396, batch_loss=0.16053591668605804
epoch: 455, epoch_loss=0.17345391621558476, batch_loss=0.16557835042476654
epoch: 456, epoch_loss=0.16797143140885265, batch_loss=0.16073386371135712
epoch: 457, epoch_loss=0.16998194441075273, batch_loss=0.1639336347579956
epoch: 458, epoch_loss=0.17176

KeyboardInterrupt: 