In [1]:
# Scratch variable holding a sample string.
a = 'mohamed'

In [4]:
def multiply_by_n(i, n):
    """Return the product of ``i`` and ``n``.

    Works for any pair of operands that support ``*`` (numbers, or a
    sequence and an int for repetition).
    """
    product = i * n
    return product

In [7]:
# Multiply every element of the list by 5; `list(...)` materialises the lazy map so the cell displays it.
list(map(lambda x: x * 5, [1,2,3,4]))

[5, 10, 15, 20]

In [8]:
def multiply_list_by_n(numbers, n):
    """Return a new list with every element of ``numbers`` multiplied by ``n``.

    Args:
        numbers: iterable of values supporting ``*``.
        n: the multiplier applied to each element.

    Returns:
        A list the same length as ``numbers``; an empty input gives ``[]``.
    """
    # The multiplier is a single value, not a second iterable, so it is
    # captured directly instead of being passed through ``map``.
    return [x * n for x in numbers]


In [12]:
# Curried identity experiment: func(a) returns a one-argument function.
# NOTE(review): the inner lambda's parameter shadows the outer `y`, so
# func(a)(b) returns b and the outer argument is never used.
func  = lambda y: lambda y: y

In [17]:
# Element-wise product of two lists: map() with a two-argument lambda walks
# both lists in lockstep. This rebinds `a` from the first cell.
a = [1,2,3,4]
b = [5,6,7,8]
ab = []  # NOTE(review): assigned but never used in this cell
list(map(lambda x, y: x * y, a, b))


[5, 12, 21, 32]

In [15]:
# Sixteen values laid out as four visual rows. Each row must end with a comma:
# without them adjacent integers meet across the line break and the literal
# is a SyntaxError.
[1,1,1,1,
 2,2,2,2,
 3,3,3,3,
 4,4,4,4]

[5, 12, 21, 32]

In [1]:
import torch
import torch.nn as nn

# A function to register hooks and capture layer-wise input/output shapes
def register_hooks(model):
    """Attach a forward hook to every leaf module of ``model`` and record shapes.

    Args:
        model: an ``nn.Module``; only sub-modules without children get a hook.

    Returns:
        A dict that starts empty and is filled in on each forward pass. Keys
        are the qualified module names from ``model.named_modules()`` (for
        example ``'layer1'``); values are
        ``{'input_shape': tuple, 'output_shape': tuple}``.

    Each hook also prints one line per call, labelled with the layer's class
    name. The first positional input and the output are assumed to be tensors.
    """
    shapes = {}

    def make_hook(layer_name):
        # One closure per layer so that two layers of the same class (e.g. two
        # Conv2d) get separate entries instead of overwriting each other.
        def hook(module, input, output):
            class_name = module.__class__.__name__
            shapes[layer_name] = {
                'input_shape': tuple(input[0].shape),
                'output_shape': tuple(output.shape),
            }
            print(f"{class_name}: Input shape: {input[0].shape}, Output shape: {output.shape}")
        return hook

    # Register hook to all modules
    for name, layer in model.named_modules():
        if len(list(layer.children())) == 0:  # Register hooks only for the final layers, not containers
            # The root module's qualified name is '', so a model that is
            # itself a leaf falls back to its class name.
            layer.register_forward_hook(make_hook(name or layer.__class__.__name__))

    return shapes

# Define any model or use any custom architecture
class ExampleModel(nn.Module):
    """Tiny demo network: Conv2d(3->64) -> ReLU -> Conv2d(64->128, stride 2).

    The first convolution keeps the spatial size (3x3 kernel, padding 1);
    the second one uses stride 2.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.activation = nn.ReLU()
        self.layer2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Run ``x`` through layer1, the activation and layer2, in that order."""
        return self.layer2(self.activation(self.layer1(x)))

# Instantiate the model and register hooks
model = ExampleModel()
shapes = register_hooks(model)  # starts empty; the hooks fill it during the forward pass below

# Create an example input and pass through the model
# (the forward call is what triggers the hooks and their printed shape lines)
input_tensor = torch.randn(1, 3, 64, 64)
output = model(input_tensor)


Conv2d: Input shape: torch.Size([1, 3, 64, 64]), Output shape: torch.Size([1, 64, 64, 64])
ReLU: Input shape: torch.Size([1, 64, 64, 64]), Output shape: torch.Size([1, 64, 64, 64])
Conv2d: Input shape: torch.Size([1, 64, 64, 64]), Output shape: torch.Size([1, 128, 32, 32])


: 

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC , abstractmethod
import math
import logging
from einops import rearrange

# Set up logging
# DEBUG level is what makes the logger.debug(...) calls in this cell (block
# names and tensor shapes) visible in the output.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

#######################################################################
################################# ABSTRACT BLOCKS #####################
#######################################################################
class XCBlock(nn.Module):
    """Marker base class for blocks invoked as ``block(x, c)``.

    ``Router.forward`` dispatches on this type and passes both the running
    tensor and the conditioning input, so the declared signature takes both.
    ``c`` defaults to ``None`` so a one-argument call to the base method still
    works.

    ``nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod`` here only
    documents intent; it does not prevent instantiation.
    """
    @abstractmethod
    def forward(self, x, c=None):
        """Subclasses override this; the base implementation returns ``None``."""
        pass

class XBlock(nn.Module):
    """Marker base class for blocks invoked as ``block(x)``."""

    @abstractmethod
    def forward(self, x):
        """Subclasses override this; the base implementation returns ``None``."""
        return None

class XTCBlock(nn.Module):
    """Marker base class for blocks invoked as ``block(x, t, c)``."""

    @abstractmethod
    def forward(self, x, t, c):
        """Subclasses override this; the base implementation returns ``None``."""
        return None

class TBlock(nn.Module):
    """Marker base class for blocks invoked as ``block(t)``."""

    @abstractmethod
    def forward(self, t):
        """Subclasses override this; the base implementation returns ``None``."""
        return None


def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``.

    The parameters are zeroed through a detached view, so the write is not
    tracked by autograd.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.zero_()
    return module


class XTBlock(nn.Module):
    """Marker base class for blocks invoked as ``block(x, t)``."""

    @abstractmethod
    def forward(self, x, t):
        """Subclasses override this; the base implementation returns ``None``."""
        return None


class Downsample(XBlock):
    """Resize the input with ``F.interpolate`` using a fixed scale factor.

    Args:
        scale_factor: multiplier passed to ``F.interpolate``.
        mode: interpolation mode passed to ``F.interpolate``.
    """

    def __init__(self, scale_factor, mode='nearest'):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by ``scale_factor`` (with ``recompute_scale_factor=True``)."""
        resize_options = dict(
            scale_factor=self.scale_factor,
            mode=self.mode,
            recompute_scale_factor=True,
        )
        return F.interpolate(x, **resize_options)



class Upsample(XBlock):
    """Resize the input with ``F.interpolate`` using a fixed scale factor.

    Args:
        scale_factor: multiplier passed to ``F.interpolate``.
        mode: interpolation mode passed to ``F.interpolate``.
    """

    def __init__(self, scale_factor, mode='nearest'):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by ``scale_factor`` using ``mode``."""
        resize_options = dict(scale_factor=self.scale_factor, mode=self.mode)
        return F.interpolate(x, **resize_options)


######################################################################
########################### TIME EMBEDDING BLOCKS ##########################
#####################################################################
class SinusoidalTimeEmbedder(TBlock):
    """Fixed (non-learned) sinusoidal embedding of a 1-D batch of timesteps.

    Args:
        base_channels: width of the returned embedding.
        max_period: controls the lowest frequency; frequencies are
            ``exp(-log(max_period) * k / half)`` for ``k`` in ``range(half)``,
            where ``half = base_channels // 2``.
    """

    def __init__(self, base_channels, max_period=10000):
        super().__init__()
        self.base_channels = base_channels
        self.max_period = max_period
        self.half = base_channels // 2
        # Plain tensor attribute (not a buffer); forward() moves it to the
        # input's device on every call.
        scaled_steps = -math.log(max_period) * torch.arange(self.half, dtype=torch.float32)
        self.freqs = torch.exp(scaled_steps / self.half)

    def forward(self, t):
        """Return ``[cos(t * freqs), sin(t * freqs)]`` with shape ``(len(t), base_channels)``."""
        angles = t[:, None].float() * self.freqs.to(t.device)
        pieces = [torch.cos(angles), torch.sin(angles)]
        if self.base_channels % 2:
            # Odd width: cos and sin cover 2 * half columns, so one zero column is appended.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)


class LearnableTimeEmbedder(TBlock):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a time embedding.

    Args:
        base_channels: feature size of the incoming embedding.
        time_embedding_dim: feature size of the hidden and output layers.
    """

    def __init__(self, base_channels, time_embedding_dim):
        super().__init__()
        mlp_layers = [
            nn.Linear(base_channels, time_embedding_dim),
            nn.SiLU(),
            nn.Linear(time_embedding_dim, time_embedding_dim),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)

    def forward(self, t):
        """Return the MLP output for ``t``."""
        return self.time_embed(t)


class TimeEmbedder(TBlock):
    """Sinusoidal embedding followed by the learnable MLP.

    Args:
        base_channels: width of the sinusoidal embedding (and MLP input).
        time_embedding_dim: width of the final embedding.
    """

    def __init__(self, base_channels, time_embedding_dim):
        super().__init__()
        stages = [
            SinusoidalTimeEmbedder(base_channels),
            LearnableTimeEmbedder(base_channels, time_embedding_dim),
        ]
        self.time_embed = nn.Sequential(*stages)

    def forward(self, t):
        """Map raw timesteps ``t`` to their learned embedding."""
        embedded = self.time_embed(t)
        return embedded



class TimeInjector(XTBlock):
    """Add a projected time embedding to a feature map, then project the sum.

    Args:
        time_embedding_dim: feature size of the incoming time embedding.
        in_channels: channel count of the feature map; GroupNorm is built with
            32 groups, so this must be divisible by 32.

    The last convolution of ``proj_out`` is passed through ``zero_module``, so
    its weights and bias start at zero.
    """

    def __init__(self, time_embedding_dim, in_channels):
        super().__init__()

        time_layers = [nn.SiLU(), nn.Linear(time_embedding_dim, in_channels)]
        self.time_embed = nn.Sequential(*time_layers)

        out_layers = [
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            zero_module(nn.Conv2d(in_channels, in_channels, 3, padding=1)),
        ]
        self.proj_out = nn.Sequential(*out_layers)

    def forward(self, x, t):
        """Return ``proj_out(x + time_embed(t))`` with ``t`` broadcast over trailing dims."""
        emb = self.time_embed(t).type(x.dtype)
        # Append singleton axes until the embedding has as many dims as x.
        missing_dims = len(x.shape) - len(emb.shape)
        for _ in range(missing_dims):
            emb = emb[..., None]
        return self.proj_out(x + emb)
    
############################################################
################### CONVOLUTIONAL BLOCKS ###################
#############################################################
class InConvBlock(XBlock):
    """GroupNorm(32 groups) -> SiLU -> 3x3 Conv2d (padding 1).

    Args:
        in_channels: input channel count (must be divisible by 32 for GroupNorm).
        out_channels: output channel count of the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stack = [
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        ]
        self.in_layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the norm/activation/conv stack to ``x``."""
        out = self.in_layers(x)
        return out
    

class OutConvBlock(XBlock):
    """GroupNorm(32 groups) -> SiLU -> zero-initialised 3x3 Conv2d.

    Args:
        out_channels: channel count of both input and output (must be
            divisible by 32 for GroupNorm).
    """

    def __init__(self, out_channels):
        super().__init__()
        stack = [
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1)),
        ]
        self.out_layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the norm/activation/conv stack to ``x``."""
        out = self.out_layers(x)
        return out
    


class ResBlock(XTBlock):
    """Residual block: ``x + injector(inconv(x), t)``.

    Args:
        in_channels: channel count, unchanged by the block.
        time_embedding_dim: feature size of the time embedding ``t``.
    """

    def __init__(self, in_channels, time_embedding_dim):
        super().__init__()
        self.inconv = InConvBlock(in_channels, in_channels)
        self.injector = TimeInjector(time_embedding_dim, in_channels)

    def forward(self, x, t):
        """Return ``x`` plus the time-conditioned residual branch."""
        residual = self.injector(self.inconv(x), t)
        return x + residual


##############################################################
############################ MAGIC ###########################
##############################################################

class Connection:
    """A skip link between two named blocks.

    The value produced by ``start_block`` is stored with :meth:`collect` and
    later merged into the input of ``target_block`` by calling
    ``operation(x, collected_tensor)``.

    Args:
        start_block: name of the block whose output is stored.
        target_block: name of the block whose input is merged with it.
        operation: two-argument callable ``(x, collected) -> merged``.
    """

    def __init__(self, start_block, target_block, operation):
        self.start_block = start_block
        self.target_block = target_block
        self.operation = operation
        # Nothing has been stored until collect() is called.
        self.collected_tensor = None

    def is_start_block(self, name):
        """Return True when ``name`` is the block this link collects from."""
        return name == self.start_block

    def is_target_block(self, name):
        """Return True when ``name`` is the block this link feeds into."""
        return name == self.target_block

    def collect(self, x):
        """Store ``x`` for the next :meth:`excute_operation` call."""
        self.collected_tensor = x

    def excute_operation(self, x):
        """Return ``operation(x, collected_tensor)``; the stored value is kept."""
        return self.operation(x, self.collected_tensor)

    def __repr__(self):
        if hasattr(self.operation, '__name__'):
            op_label = self.operation.__name__
        else:
            op_label = str(self.operation)
        state = 'None' if self.collected_tensor is None else 'Set'
        return (f"Connection(start_block={self.start_block}, "
                f"target_block={self.target_block}, "
                f"operation={op_label}, "
                f"collected_tensor={state})")



class Router(nn.Sequential):
    """Sequential container that dispatches by block type and wires skip links.

    Args:
        connections: list of ``Connection`` objects, matched against child
            names as returned by ``named_children()``.
        *args: the child blocks, forwarded to ``nn.Sequential``.
    """

    def __init__(self, connections, *args):
        super().__init__(*args)

        self.connections = connections

    def forward(self, x, t, c=None):
        """Run every child in order and return the final ``x``.

        For each child: first merge any connection that targets it, then call
        it with the arguments its marker base class implies, then store its
        output on any connection that starts at it. A child that matches none
        of the marker classes is not called and leaves ``x`` unchanged.
        """
        for block_name, block in self.named_children():
            logger.debug(f'{block.__class__.__name__}')

            # Merge stored skip tensors into the input of this block.
            for connection in self.connections:
                if not connection.is_target_block(block_name):
                    continue
                x = connection.excute_operation(x)
                logger.debug(f'{x.shape}')

            logger.debug(f'{x.shape}')

            # The order of these checks is the dispatch priority.
            if isinstance(block, XTBlock):
                call_args = (x, t)
            elif isinstance(block, TBlock):
                call_args = (t,)
            elif isinstance(block, XBlock):
                call_args = (x,)
            elif isinstance(block, XTCBlock):
                call_args = (x, t, c)
            elif isinstance(block, XCBlock):
                call_args = (x, c)
            else:
                call_args = None
            if call_args is not None:
                x = block(*call_args)

            logger.debug(f'{x.shape}')

            # Store this block's output for any skip link that starts here.
            for connection in self.connections:
                if connection.is_start_block(block_name):
                    connection.collect(x)
        return x
    
class ChannelChanger(XBlock):
    """3x3 convolution (padding 1) that maps ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``x`` with its channel count changed by the convolution."""
        projected = self.layer(x)
        return projected


def connection_finder(unet, name):
    """Pair up blocks of class ``name`` from the outside in and link them.

    Args:
        unet: ordered iterable of blocks.
        name: class name (``block.__class__.__name__``) to look for.

    Returns:
        A list of ``Connection`` objects using ``concat``: the first match is
        linked to the last, the second to the second-to-last, and so on.
        Positions are stored as strings. With an odd number of matches the
        middle one is left unpaired; with zero or one match the list is empty.
    """
    names = [str(i) for i, block in enumerate(unet) if block.__class__.__name__ == name]

    left = []
    right = []
    # `> 1` rather than `!= 1`: with zero or an even number of matches the
    # count never equals 1, and popping from the emptied list would raise
    # IndexError.
    while len(names) > 1:
        left.append(names.pop(0))
        right.append(names.pop())

    connections = [Connection(first, second, concat) for first, second in zip(left, right)]
    logger.debug(f'left: {left}')
    logger.debug(f'right: {right}')

    return connections



def concat(x, y):
    """Concatenate ``x`` and ``y`` along dimension 1, with ``x`` first."""
    pair = (x, y)
    return torch.cat(pair, dim=1)

################################################################
################### ATTENTION BLOCKS ###########################
################################################################
                    
class CrossAttention(XCBlock):
    """Multi-head cross-attention with pre-LayerNorm and a residual connection.

    Queries are computed from ``x``; keys and values from ``context``.

    Args:
        d_query: feature size of ``x`` (its last dimension).
        d_context: feature size of ``context`` (its last dimension).
        n_heads: number of attention heads.
        head_dim: feature size per head; the projections have width
            ``n_heads * head_dim``.
    """

    def __init__(self, d_query, d_context, n_heads, head_dim):
        super().__init__()

        inner_dim = n_heads*head_dim
        self.n_heads = n_heads
        self.scale = head_dim**-0.5

        self.Q = nn.Linear(d_query, inner_dim, bias=False)
        self.K = nn.Linear(d_context, inner_dim, bias=False)
        self.V = nn.Linear(d_context, inner_dim, bias=False)

        self.proj_out = nn.Linear(inner_dim, d_query)
        self.norm = nn.LayerNorm(d_query)

    def forward(self, x, context):
        """Attend from ``x`` to ``context`` and return ``x`` plus the result.

        Both inputs are rank-3 ``(batch, tokens, features)`` tensors, as
        required by the ``'b n (h d)'`` rearrange pattern. Only ``x`` is
        normalised; ``context`` is used as given.
        """
        h = self.norm(x)

        Q = self.Q(h)
        K = self.K(context)
        V = self.V(context)

        # Fold the heads into the batch axis so one einsum handles every head.
        Q, K, V = map(lambda t: rearrange( t, 'b n (h d) -> (b h) n d', h=self.n_heads), (Q, K, V))

        # Scale the logits BEFORE the softmax. Scaling afterwards would leave
        # the softmax temperature untouched and make each row of weights sum
        # to `scale` instead of 1.
        scores = torch.einsum('b i d, b j d -> b i j', Q, K) * self.scale
        attention = scores.softmax(dim=-1)

        values = torch.einsum('b i j, b j d -> b i d', attention, V)
        values = rearrange(values, '(b h) n d -> b n (h d)', h=self.n_heads)
        values = self.proj_out(values)
        return x + values


class SelfAttention(XBlock):
    """Multi-head self-attention with pre-LayerNorm and a residual connection.

    Args:
        d: feature size of the input (its last dimension).
        n_heads: number of attention heads.
        head_dim: feature size per head; the projections have width
            ``n_heads * head_dim``.
    """

    def __init__(self, d, n_heads, head_dim):
        super().__init__()

        inner_dim = n_heads*head_dim
        self.n_heads = n_heads
        self.scale  = head_dim**-0.5

        self.Q = nn.Linear(d, inner_dim, bias=False)
        self.K = nn.Linear(d, inner_dim, bias=False)
        self.V = nn.Linear(d, inner_dim, bias=False)

        self.proj_out = nn.Linear(inner_dim, d)
        self.norm = nn.LayerNorm(d)

    def forward(self, x):
        """Attend from ``x`` to itself and return ``x`` plus the result.

        ``x`` is a rank-3 ``(batch, tokens, features)`` tensor, as required by
        the ``'b n (h d)'`` rearrange pattern.
        """
        h = self.norm(x)

        Q = self.Q(h)
        K = self.K(h)
        V = self.V(h)

        # Fold the heads into the batch axis so one einsum handles every head.
        Q, K, V = map(lambda t: rearrange( t, 'b n (h d) -> (b h) n d', h=self.n_heads), (Q, K, V))

        # Scale the logits BEFORE the softmax. Scaling afterwards would leave
        # the softmax temperature untouched and make each row of weights sum
        # to `scale` instead of 1.
        scores = torch.einsum('b i d, b j d -> b i j', Q, K) * self.scale
        attention = scores.softmax(dim=-1)

        values = torch.einsum('b i j, b j d -> b i d', attention, V)
        values = rearrange(values, '(b h) n d -> b n (h d)', h=self.n_heads)
        values = self.proj_out(values)
        return x + values

class GEGLU(XBlock):
    """GELU-gated linear unit: one projection split into a value and a gate.

    Args:
        dim_in: input feature size.
        dim_out: output feature size; the projection is ``2 * dim_out`` wide.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        """Return ``value * gelu(gate)``, both halves taken from ``proj(x)``."""
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
    

class FeedForwardGEGLU(XBlock):
    """Residual feed-forward block: LayerNorm -> GEGLU (4x width) -> Dropout -> Linear.

    Args:
        d_query: input and output feature size.
        dropout: dropout probability applied after the GEGLU.
    """

    def __init__(self, d_query, dropout=0.):
        super().__init__()
        hidden = d_query * 4
        stack = [
            nn.LayerNorm(d_query),
            GEGLU(d_query, hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_query),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the feed-forward output added to ``x``."""
        update = self.net(x)
        return update + x
    
class FeedForwardGLU(XBlock):
    """Residual feed-forward block: LayerNorm -> Linear (4x) -> GELU -> Dropout -> Linear.

    Despite the class name, no gating is applied in this stack.

    Args:
        d_query: input and output feature size.
        dropout: dropout probability applied after the GELU.
    """

    def __init__(self, d_query, dropout=0.):
        super().__init__()
        hidden = 4 * d_query
        stack = [
            nn.LayerNorm(d_query),
            nn.Linear(d_query, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_query),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the feed-forward output added to ``x``."""
        update = self.net(x)
        return update + x
    

class Adapter(XBlock):
    """Flatten a ``(b, c, h, w)`` feature map into a ``(b, h*w, c)`` token sequence."""

    def forward(self, x):
        """Return ``x`` with the spatial axes merged and channels moved last."""
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens
    
    
class UnetBlock(XTCBlock):
    """ResBlock followed by self-attention, cross-attention and a feed-forward layer.

    The stages run inside a ``Router`` with no skip connections. The feature
    map is flattened to tokens after the ResBlock and reshaped back to
    ``(b, c, h, w)`` at the end.

    Args:
        in_channels: channel count of the feature map (also the token width).
        context_dim: feature size of the conditioning ``context``.
        time_embedding_dim: feature size of the time embedding ``t``.
        n_heads: number of attention heads.
        head_dim: feature size per attention head.
    """

    def __init__(self,
                 in_channels,
                 context_dim=128,
                 time_embedding_dim=128,
                 n_heads=8,
                 head_dim=32):
        super().__init__()
        stages = [
            ResBlock(in_channels, time_embedding_dim),
            Adapter(),
            SelfAttention(in_channels, n_heads, head_dim),
            CrossAttention(in_channels, context_dim, n_heads, head_dim),
            FeedForwardGEGLU(in_channels),
        ]
        self.block = Router([], *stages)

    def forward(self, x, t, context):
        """Run the stages on a rank-4 ``x`` and return a tensor of the same spatial size."""
        # Remember the spatial size; the token sequence loses it.
        _, _, height, width = x.shape
        tokens = self.block(x, t, context)
        return rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)

#######################################################
####################### UNET ##########################
#######################################################

class ResChain(XTBlock):
    """A stack of ``num_resblocks`` ResBlocks applied one after another.

    Args:
        in_channels: channel count, unchanged by the chain.
        num_resblocks: how many ResBlocks to build.
        time_embedding_dim: feature size of the time embedding ``t``.
    """

    def __init__(self, in_channels, num_resblocks, time_embedding_dim=128):
        super().__init__()
        chain = []
        for _ in range(num_resblocks):
            chain.append(ResBlock(in_channels, time_embedding_dim))
        self.resblocks = nn.ModuleList(chain)

    def forward(self, x, t):
        """Pass ``x`` through every ResBlock, giving each the same ``t``."""
        out = x
        for resblock in self.resblocks:
            out = resblock(out, t)
        return out
    
class Unet(nn.Module):
    """Time-conditioned U-Net built from ResChains joined by concat skip links.

    Layout (all blocks live in one ``Router``):
      * down: for each pair of consecutive widths, ResChain -> Downsample(0.5)
        -> ChannelChanger;
      * mid: one ResChain at the last width;
      * up: for each pair in reverse, Upsample(2) -> ChannelChanger ->
        ResChain(2 * width) -> ChannelChanger back to ``width``.
    ``connection_finder`` pairs the ResChains outside-in, so each up-path
    ResChain receives its input concatenated with the output of the matching
    down-path ResChain (hence the doubled width).

    Args:
        channels: sequence of channel widths, outermost first. Every block
            uses GroupNorm with 32 groups, so each width must be divisible
            by 32.
        num_resblocks: number of ResBlocks per ResChain.

    The input and output projections are built for single-channel images.
    """

    def __init__(self, channels=(64, 128, 256, 512), num_resblocks=5):
        super().__init__()
        # Copy into a fresh list: the default is an immutable tuple and a
        # caller-supplied sequence is never shared with the module.
        channels = list(channels)

        self.in_proj = nn.Conv2d(1, channels[0], 3, padding=1)
        self.out_proj = nn.Conv2d(channels[0], 1, 3, padding=1)

        # 128 matches the default time_embedding_dim of ResChain.
        self.time_embedder = TimeEmbedder(channels[0], time_embedding_dim=128)

        # self.down / self.mid / self.up are plain Python lists; the blocks
        # are registered as sub-modules through self.unet below.
        self.down = []
        for width_in, width_out in zip(channels[:-1], channels[1:]):
            self.down.append(ResChain(width_in, num_resblocks))
            self.down.append(Downsample(0.5, 'nearest'))
            self.down.append(ChannelChanger(width_in, width_out))

        self.mid = [ResChain(channels[-1], num_resblocks)]

        channels_reversed = list(reversed(channels))
        self.up = []
        for width_in, width_out in zip(channels_reversed[:-1], channels_reversed[1:]):
            self.up.append(Upsample(2, 'nearest'))
            self.up.append(ChannelChanger(width_in, width_out))
            # The skip tensor is concatenated on the channel axis before this chain.
            self.up.append(ResChain(2 * width_out, num_resblocks))
            self.up.append(ChannelChanger(2 * width_out, width_out))

        unet = self.down + self.mid + self.up
        connections = connection_finder(unet, 'ResChain')
        self.unet = Router(connections, *unet)

    def forward(self, x, t):
        """Embed the timesteps ``t``, run ``x`` through the U-Net and project back to one channel."""
        t = self.time_embedder(t)
        x = self.in_proj(x)
        x = self.unet(x, t)
        return self.out_proj(x)







In [22]:
# Build the U-Net with its default arguments. This rebinds `model`, which an
# earlier cell bound to ExampleModel.
model = Unet()

2024-11-14 17:43:09,149 - DEBUG - left: ['0', '3', '6']
2024-11-14 17:43:09,149 - DEBUG - right: ['20', '16', '12']


In [56]:
# Shape smoke test for the attention blocks on an (4, 8, 512) token batch with
# a (4, 4, 128) context.
att1 = CrossAttention(512, 128, 8, 32)
att2 = SelfAttention(512, 8, 32)
# NOTE(review): TransformerBlock is not defined anywhere in the visible
# notebook; this line depends on a definition from an earlier kernel state.
trans = TransformerBlock(512, 128, 8, 32)
x = torch.randn(4, 8, 512)
context = torch.randn((4, 4, 128))
print(att1(x, context).shape)
print(att2(x).shape)
print(trans(x, context).shape)


torch.Size([4, 8, 512])
torch.Size([4, 8, 512])
torch.Size([4, 8, 512])


In [88]:
# Shape smoke test for UnetBlock.
# NOTE(review): the five names below are not passed to anything in this cell;
# UnetBlock and TimeEmbedder are built with literals and their own defaults.
in_channels = 512
n_heads = 8
head_dim = 32
context_dim = 128
time_embedding_dim = 128


x = torch.randn(4, 512, 4, 4)
context = torch.randn((4, 4, 128))
embed = TimeEmbedder(base_channels=64, time_embedding_dim=128)
# A single timestep (batch of one); it broadcasts against the batch of 4 when
# it is added to the feature map.
t = embed(torch.tensor(100)[None])

block = UnetBlock(in_channels=512)

block(x, t, context).shape

torch.Size([4, 512, 4, 4])

In [30]:
# Shape smoke test for the full U-Net: one single-channel 128x128 input and one
# random integer timestep drawn from [0, 1000).
x = torch.randn(1, 1, 128, 128)
t = torch.randint(1000 , size=(1,))
model(x, t).shape

2024-11-14 17:45:54,676 - DEBUG - ResChain
2024-11-14 17:45:54,676 - DEBUG - torch.Size([1, 64, 128, 128])


2024-11-14 17:45:54,790 - DEBUG - torch.Size([1, 64, 128, 128])
2024-11-14 17:45:54,790 - DEBUG - Downsample
2024-11-14 17:45:54,790 - DEBUG - torch.Size([1, 64, 128, 128])
2024-11-14 17:45:54,790 - DEBUG - torch.Size([1, 64, 64, 64])
2024-11-14 17:45:54,790 - DEBUG - ChannelChanger
2024-11-14 17:45:54,790 - DEBUG - torch.Size([1, 64, 64, 64])
2024-11-14 17:45:54,790 - DEBUG - torch.Size([1, 128, 64, 64])
2024-11-14 17:45:54,799 - DEBUG - ResChain
2024-11-14 17:45:54,799 - DEBUG - torch.Size([1, 128, 64, 64])
2024-11-14 17:45:54,846 - DEBUG - torch.Size([1, 128, 64, 64])
2024-11-14 17:45:54,846 - DEBUG - Downsample
2024-11-14 17:45:54,846 - DEBUG - torch.Size([1, 128, 64, 64])
2024-11-14 17:45:54,846 - DEBUG - torch.Size([1, 128, 32, 32])
2024-11-14 17:45:54,846 - DEBUG - ChannelChanger
2024-11-14 17:45:54,846 - DEBUG - torch.Size([1, 128, 32, 32])
2024-11-14 17:45:54,860 - DEBUG - torch.Size([1, 256, 32, 32])
2024-11-14 17:45:54,861 - DEBUG - ResChain
2024-11-14 17:45:54,861 - DEBUG -

torch.Size([1])


2024-11-14 17:45:54,901 - DEBUG - torch.Size([1, 256, 32, 32])
2024-11-14 17:45:54,901 - DEBUG - Downsample
2024-11-14 17:45:54,901 - DEBUG - torch.Size([1, 256, 32, 32])
2024-11-14 17:45:54,909 - DEBUG - torch.Size([1, 256, 16, 16])
2024-11-14 17:45:54,910 - DEBUG - ChannelChanger
2024-11-14 17:45:54,911 - DEBUG - torch.Size([1, 256, 16, 16])
2024-11-14 17:45:54,915 - DEBUG - torch.Size([1, 512, 16, 16])
2024-11-14 17:45:54,915 - DEBUG - ResChain
2024-11-14 17:45:54,915 - DEBUG - torch.Size([1, 512, 16, 16])
2024-11-14 17:45:54,949 - DEBUG - torch.Size([1, 512, 16, 16])
2024-11-14 17:45:54,949 - DEBUG - Upsample
2024-11-14 17:45:54,949 - DEBUG - torch.Size([1, 512, 16, 16])
2024-11-14 17:45:54,963 - DEBUG - torch.Size([1, 512, 32, 32])
2024-11-14 17:45:54,964 - DEBUG - ChannelChanger
2024-11-14 17:45:54,964 - DEBUG - torch.Size([1, 512, 32, 32])
2024-11-14 17:45:54,971 - DEBUG - torch.Size([1, 256, 32, 32])
2024-11-14 17:45:54,971 - DEBUG - ResChain
2024-11-14 17:45:54,974 - DEBUG - t

torch.Size([1, 1, 128, 128])

In [11]:
import torch
import torch.nn as nn
import numpy as np
from abc import ABC, abstractmethod
from functools import partial
from inspect import isfunction
from refactor.schedules import extract_into_tensor, exists, default, LinearSchedule, BetaSchedule
import matplotlib.pyplot as plt
from data.LoadData import load_single_aws_zarr, AWS_ZARR_ROOT
# Shorthand for building float32 tensors: to_torch(x) == torch.tensor(x, dtype=torch.float32).
to_torch = partial(torch.tensor, dtype=torch.float32)



import os
import s3fs
import zarr
from typing import Union
import dask.array as da

# Root S3 URL of the zarr store; a year is appended to it where the data is loaded.
# NOTE(review): this rebinds the AWS_ZARR_ROOT name imported from
# data.LoadData earlier in this cell.
AWS_ZARR_ROOT = (
    "s3://gov-nasa-hdrl-data1/contrib/fdl-sdoml/fdl-sdoml-v2/sdomlv2.zarr/"
)


def s3_connection(path_to_zarr: str) -> s3fs.S3Map:
    """Create an anonymous S3 key-value mapping rooted at ``path_to_zarr``.

    Args:
        path_to_zarr: S3 URL of the zarr store, passed as the map's root.

    Returns:
        An ``s3fs.S3Map`` backed by an anonymous ``S3FileSystem``, created
        with ``check=False``.
    """
    # anonymous access requires no credentials
    anonymous_fs = s3fs.S3FileSystem(anon=True)
    return s3fs.S3Map(
        root=path_to_zarr,
        s3=anonymous_fs,
        check=False,
    )


def load_single_aws_zarr(
    path_to_zarr: str,
    cache_max_single_size: Union[int, None] = None,
    wavelength='171A',
) -> da.Array:
    """Open a zarr store on S3 read-only and return one entry as a dask array.

    Args:
        path_to_zarr: S3 URL of the zarr store to open.
        cache_max_single_size: ``max_size`` handed to ``zarr.LRUStoreCache``;
            ``None`` (the default) is passed through unchanged.
        wavelength: key looked up in the opened root (default ``'171A'``).

    Returns:
        ``dask.array.from_array`` applied to ``root[wavelength]``; nothing is
        computed until the caller asks for it.

    Raises:
        KeyError: presumably, if ``wavelength`` is not present in the store —
            confirm against the zarr version in use.
    """
    cached_store = zarr.LRUStoreCache(
        store=s3_connection(path_to_zarr),
        max_size=cache_max_single_size,
    )
    root = zarr.open(cached_store, mode="r")
    data = da.from_array(root[wavelength])

    return data


# Open the 2015 store at cell-execution time and keep the '171A' entry as a
# dask array; this contacts S3 as soon as the cell runs.
data = load_single_aws_zarr(
        path_to_zarr=AWS_ZARR_ROOT + str(2015),
        wavelength='171A'
    )
class Diffusion(nn.Module):
    """Holds the noise-schedule constants of a diffusion process as buffers.

    All constants are derived once from ``schedule.betas()`` in
    ``register_schedule`` and stored as float32 buffers; ``add_noise`` uses
    them to produce a noised sample for given timesteps.
    """
    def __init__(self,
                schedule: BetaSchedule,
                parameterization: str ='eps',
                v_posterior: float = 0.0
                ):
        """Store the settings and register every schedule-derived buffer.

        Args:
            schedule: object providing ``betas()`` (a tensor) and ``timesteps``.
            parameterization: ``'eps'`` or ``'x0'``; any other value makes
                ``register_schedule`` raise ``NotImplementedError``.
            v_posterior: mixing weight between the two posterior-variance
                terms; 0.0 uses only the first term.
        """
        
        super().__init__()
        self.parameterization = parameterization
        self.v_posterior = v_posterior
        self.schedule = schedule
        self.register_schedule()

    def register_schedule(self):
        """Compute the schedule constants in NumPy and register them as buffers."""
        
        betas =  self.schedule.betas().numpy()
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 for the step before the first.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        assert alphas_cumprod.shape[0] == self.schedule.timesteps
        
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # Quantities used for the forward (noising) process and its inversions.
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # Blend of beta * (1 - acp_prev) / (1 - acp) and plain beta, weighted by v_posterior.
        # The first term is 0 at index 0 because alphas_cumprod_prev[0] == 1.
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
        
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # Clipped at 1e-20 before the log because the variance can be 0 at index 0.
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))


        # Per-timestep loss weights; the formula depends on the parameterization.
        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): np.sqrt is applied to a torch tensor here, and `2. * 1 - x`
            # evaluates as `2 - x` — confirm both are intended.
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        
        # TODO how to choose this term
        # The first weight is replaced by the second one; in the "eps" case with
        # v_posterior == 0 it would otherwise divide by posterior_variance[0] == 0.
        lvlb_weights[0] = lvlb_weights[1]
        # persistent=False keeps this buffer out of the state_dict.
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # NOTE(review): this only fails when EVERY weight is NaN; `.any()` may be intended.
        assert not torch.isnan(self.lvlb_weights).all()
    
    def add_noise(self, x0, t):
        """Return ``(xt, noise)`` where ``xt`` is ``x0`` noised for timesteps ``t``.

        ``xt = sqrt_alphas_cumprod[t] * x0 + sqrt_one_minus_alphas_cumprod[t] * noise``
        with ``noise`` drawn by ``torch.randn_like(x0)``, so ``x0`` must be a
        tensor.
        """
        noise = torch.randn_like(x0)
        
        # extract_into_tensor is an imported helper; presumably it gathers the
        # coefficient for each entry of t and reshapes it to broadcast against
        # x0 — confirm in refactor.schedules.
        sqrt_alphas_cumprod = extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape)
        sqrt_one_minus_alphas_cumprod = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        
        xt = sqrt_alphas_cumprod * x0 + sqrt_one_minus_alphas_cumprod * noise
        return xt, noise



if __name__ == '__main__':
    # Demo: noise the first image of the loaded dataset at timestep 1 and show it.
    schedule = LinearSchedule()
    diffusion = Diffusion(schedule)
    # dask's .compute() returns a NumPy array, but add_noise calls
    # torch.randn_like (and .cpu() is called on its result below), so the
    # image is converted to a float32 tensor first.
    x = to_torch(data[0].compute())
    xt, noise = diffusion.add_noise(x, torch.tensor(1)[None])
    plt.imshow(xt.cpu().numpy(), cmap='afmhot')

KeyboardInterrupt: 

In [9]:
# Re-open the 2015 store as a lazy dask array; nothing is downloaded until
# .compute() is called (left commented out below).
data = load_single_aws_zarr(path_to_zarr=AWS_ZARR_ROOT + str(2015), wavelength='171A')
#data[0].compute()


In [7]:
import torch
# Mean squared error between two independent standard-normal tensors of the same shape.
torch.nn.MSELoss()(torch.randn(1, 1, 4, 512, 512), torch.randn(1, 1, 4, 512, 512))

tensor(1.9981)

In [125]:
# NOTE(review): down_blocks, middle_blocks and up_blocks are not defined
# anywhere in the visible notebook; this cell relies on earlier kernel state.
unet = down_blocks + middle_blocks + up_blocks

In [13]:
# Hand-written skip links, paired outside-in: each Connection names a start
# block and a target block and merges the two tensors with concat.
connections = [Connection('0','12', concat),
               Connection('2','10', concat),
               Connection('4','8', concat)
               ]

In [1]:
from torchvision import models
import torch
import torch.nn as nn
# Load VGG-16 with pretrained weights (downloaded to the torch hub cache on first use).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version.
vgg = models.vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\mhesh/.cache\torch\hub\checkpoints\vgg16-397923af.pth
100%|██████████| 528M/528M [00:55<00:00, 9.92MB/s] 
