In [1]:
# model test 用

import copy
from copy import deepcopy
import math

import torch
from torch import nn 

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.autograd import Variable

In [2]:
# my vit

import copy
import math

import torch
from torch import nn 

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.autograd import Variable

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 12, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        # print("Attention_out:", self.to_out(out).shape)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, fc_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, fc_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, in_size, num_obs, out_size, dim, depth, heads, fc_dim, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        position = PositionalEncoding(dim, dropout)
        c = copy.deepcopy
        self.input_embedding = nn.Sequential(LinearEmbedding(in_size,dim), c(position))
        
        self.pos_embedding = nn.Parameter(torch.randn(1, num_obs + 12, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 12, dim))
        
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, fc_dim, dropout)

        #self.pool = pool
        self.generator = Generator(dim, out_size)

    def forward(self, input):
        # print("input:", input.shape)
        x = self.input_embedding(input)
        b, n, _ = x.shape
        # print("in_en:", x.shape)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 12)]
        x = self.dropout(x)

        x = self.transformer(x)

        # print("out_TF:", x.shape)

        #x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.generator(x)
        return x

class PositionalEncoding(nn.Module):
    """
    Implement the PE function.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print("pos_x:", x.shape)
        # print("pos_e:", Variable(self.pe[:, :x.size(1)], requires_grad=False).shape)
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

class LinearEmbedding(nn.Module):
    def __init__(self, inp_size,d_model):
        super(LinearEmbedding, self).__init__()
        # lut => lookup table
        self.lut = nn.Linear(inp_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)



class Generator(nn.Module):
    "Define standard linear + softmax generation step."

    def __init__(self, dim, out_size):
        super(Generator, self).__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, out_size)
        )

    def forward(self, x):
        # print("generator:", x.shape)
        return self.proj(x)

In [None]:
# original vit

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 20, dim_head = 1600, dropout = 0.1):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, fc_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, fc_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, num_pred, dim, depth, heads = 20, fc_dim, dim_head = 1600, dropout = 0.1, emb_dropout = 0.1):
        super().__init__()
        
        x_dim = 40
        y_dim = 40
        z_dim = 20
        
        patch_dim = x_dim * y_dim
        
        self.to_patch_embedding = nn.Sequential(
            reshape = Rearrange('b x y z data -> b (x y) z data', x = x_dim, y_dim = 40, z_dim = 20)
            nn.Linear(patch_dim, dim),
        )
        
        self.num_pred = num_pred

        self.pos_embedding = nn.Parameter(torch.randn(1, patch_dim + num_pred, dim))
        self.pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, fc_dim, dropout)

        self.generator = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_pred)
        )

    def forward(self, data):
        x = self.to_patch_embedding(data)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((pred_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + num_pred)]
        x = self.dropout(x)

        x = self.transformer(x)

        return self.generator(x)

In [5]:
# Mazda_? xy平面をパッチとして入力2

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

x_dim = 40
y_dim = 40
z_dim = 20
data_dim = 12

patch_dim = x_dim * y_dim * data_dim

patch_reshape = Rearrange('b x y z data -> b z (x y data)', x = x_dim, y = y_dim, z = z_dim, data = data_dim)
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(patch_dim, dim)
x = patch_L(x)
print("patch_emb:", x.shape)

#for num_z in range(z_dim):
#    if num_z == 0:
#        patch_x = patch_L(x[:,:,num_z,:])
#    else:
#        patch_x = torch.cat((patch_x, patch_L(x[:,:,num_z,:])), dim=1)
#print("patch_Linear:", patch_x.shape)

b, n, _ = x.shape
num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, patch_dim + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

dim_head = 512
heads = 20
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_Attn", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

dim_spec = 7
spec = torch.randn(1, dim_spec)
print("spec:", spec.shape)

spec_tokens = repeat(spec, '() d -> b d', b = b)
print("spec_tokens:", spec_tokens.shape)

x = torch.cat((spec_tokens, x), dim=1)
print("cat(spec_tokens,x):", x.shape)

generator = nn.Linear(dim+dim_spec, num_pred)
pred = generator(x)
print("model_out:", pred.shape)

inp: torch.Size([1, 40, 40, 20, 12])
patch_reshape: torch.Size([1, 20, 19200])
patch_emb: torch.Size([1, 20, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 121, 512])
pos_emb: torch.Size([1, 121, 512])
pos_dropout: torch.Size([1, 121, 512])
PreNorm_Attn torch.Size([1, 121, 512])
qkv: torch.Size([1, 121, 30720])
qkv[0]: torch.Size([1, 121, 10240])
qkv[1]: torch.Size([1, 121, 10240])
qkv[2]: torch.Size([1, 121, 10240])
q: torch.Size([1, 20, 121, 512])
k: torch.Size([1, 20, 121, 512])
v: torch.Size([1, 20, 121, 512])
dots: torch.Size([1, 20, 121, 121])
attn: torch.Size([1, 20, 121, 121])
attn_score: torch.Size([1, 20, 121, 512])
attn_reshape torch.Size([1, 121, 10240])
attn_Linear: torch.Size([1, 121, 512])
attn_out: torch.Size([1, 121, 512])
x_attn_out: torch.Size([1, 121, 512])
PreNorm_Attn torch.Size([1, 121, 512])
ff_Linear1: torch.Size([1, 121, 2048])
ff_GELU: torch.Size([1, 121, 2048])
ff_dropout1 torch.Size([1, 

In [None]:
# Mazda_? 変数を1次元に並べる

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

x_dim = 40
y_dim = 40
z_dim = 20
data_dim = 12

patch_dim = x_dim * y_dim * z_dim

patch_reshape = Rearrange('b x y z data -> b (x y z) data', x = x_dim, y = y_dim, z = z_dim, data = data_dim)
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(data_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape
num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, patch_dim + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

dim_head = 512
heads = 20
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_Attn", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

dim_spec = 7
spec = torch.randn(1, dim_spec)
print("spec:", spec.shape)

spec_tokens = repeat(spec, '() d -> b d', b = b)
print("spec_tokens:", spec_tokens.shape)

x = torch.cat((spec_tokens, x), dim=1)
print("cat(spec_tokens,x):", x.shape)

generator = nn.Linear(dim+dim_spec, num_pred)
pred = generator(x)
print("model_out:", pred.shape)

inp: torch.Size([1, 40, 40, 20, 12])
patch_reshape: torch.Size([1, 32000, 12])
patch_Linear: torch.Size([1, 32000, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 32101, 512])
pos_emb: torch.Size([1, 32101, 512])
pos_dropout: torch.Size([1, 32101, 512])
PreNorm_Attn torch.Size([1, 32101, 512])
qkv: torch.Size([1, 32101, 30720])
qkv[0]: torch.Size([1, 32101, 10240])
qkv[1]: torch.Size([1, 32101, 10240])
qkv[2]: torch.Size([1, 32101, 10240])
q: torch.Size([1, 20, 32101, 512])
k: torch.Size([1, 20, 32101, 512])
v: torch.Size([1, 20, 32101, 512])


In [23]:
# Mazda_? 変数を1次元に並べる

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

x_dim = 40
y_dim = 40
z_dim = 20
data_dim = 12

patch_dim = x_dim * y_dim * z_dim

patch_reshape = Rearrange('b x y z data -> b (x y z) data', x = x_dim, y = y_dim, z = z_dim, data = data_dim)
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(data_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape
num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, patch_dim + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

dim_head = 512
heads = 20
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_Attn", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

dim_spec = 7
spec = torch.randn(1, dim_spec)
print("spec:", spec.shape)

spec_tokens = repeat(spec, '() d -> b d', b = b)
print("spec_tokens:", spec_tokens.shape)

x = torch.cat((spec_tokens, x), dim=1)
print("cat(spec_tokens,x):", x.shape)

generator = nn.Linear(dim+dim_spec, num_pred)
pred = generator(x)
print("model_out:", pred.shape)

inp: torch.Size([2, 40, 40, 20, 12])
patch_reshape: torch.Size([2, 32000, 12])
patch_Linear: torch.Size([2, 32000, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([2, 101, 512])
cat(pred_tokens,x): torch.Size([2, 32101, 512])
pos_emb: torch.Size([2, 32101, 512])
pos_dropout: torch.Size([2, 32101, 512])
PreNorm_Attn torch.Size([2, 32101, 512])
qkv: torch.Size([2, 32101, 30720])
qkv[0]: torch.Size([2, 32101, 10240])
qkv[1]: torch.Size([2, 32101, 10240])
qkv[2]: torch.Size([2, 32101, 10240])
q: torch.Size([2, 20, 32101, 512])
k: torch.Size([2, 20, 32101, 512])
v: torch.Size([2, 20, 32101, 512])


RuntimeError: [enforce fail at CPUAllocator.cpp:68] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 164875872160 bytes. Error code 12 (Cannot allocate memory)

In [None]:
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

inp: torch.Size([1, 40, 40, 20, 12])


In [29]:
# パッチ分割

patch_split_size = [1,1,20]

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

_, x_dim, y_dim, z_dim, data_dim = inp.shape
print("x_dim:%i, y_dim:%i, z_dim:%i" % (x_dim, y_dim, z_dim))
patch_x = x_dim // patch_split_size[0] 
patch_y = y_dim // patch_split_size[1]
patch_z = z_dim // patch_split_size[2]
print("patch_x:%i, patch_y:%i, patch_z:%i" % (patch_x, patch_y, patch_z))

num_patches = (x_dim // patch_x) * (y_dim // patch_y) * (z_dim // patch_z)

patch_dim = patch_x * patch_y * patch_z * data_dim

patch_reshape = Rearrange('b (x p1) (y p2) (z p3) data -> b (x y z) (p1 p2 p3 data)', p1=patch_x, p2=patch_y, p3=patch_z)  
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(patch_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape
num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, num_patches + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)



#dim_head = 512
#heads = num_patches
dim_head = num_patches
heads = 8
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_ff", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

dim_spec = 7
spec = torch.randn(1, dim_spec)
print("spec:", spec.shape)

spec_tokens = repeat(spec, '() d -> b d', b = b)
print("spec_tokens:", spec_tokens.shape)

x = torch.cat((spec_tokens, x), dim=1)
print("cat(spec_tokens,x):", x.shape)

generator = nn.Linear(dim+dim_spec, num_pred)
pred = generator(x)
print("model_out:", pred.shape)



inp: torch.Size([1, 40, 40, 20, 12])
x_dim:40, y_dim:40, z_dim:20
patch_x:40, patch_y:40, patch_z:1
patch_reshape: torch.Size([1, 20, 19200])
patch_Linear: torch.Size([1, 20, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 121, 512])
pos_emb: torch.Size([1, 121, 512])
pos_dropout: torch.Size([1, 121, 512])
PreNorm_Attn torch.Size([1, 121, 512])
qkv: torch.Size([1, 121, 480])
qkv[0]: torch.Size([1, 121, 160])
qkv[1]: torch.Size([1, 121, 160])
qkv[2]: torch.Size([1, 121, 160])
q: torch.Size([1, 8, 121, 20])
k: torch.Size([1, 8, 121, 20])
v: torch.Size([1, 8, 121, 20])
dots: torch.Size([1, 8, 121, 121])
attn: torch.Size([1, 8, 121, 121])
attn_score: torch.Size([1, 8, 121, 20])
attn_reshape torch.Size([1, 121, 160])
attn_Linear: torch.Size([1, 121, 512])
attn_out: torch.Size([1, 121, 512])
x_attn_out: torch.Size([1, 121, 512])
PreNorm_ff torch.Size([1, 121, 512])
ff_Linear1: torch.Size([1, 121, 2048])
ff_GELU: torch.Size

In [34]:
# パッチ分割 spec

patch_split_size = [1,1,20]

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

_, x_dim, y_dim, z_dim, data_dim = inp.shape
print("x_dim:%i, y_dim:%i, z_dim:%i" % (x_dim, y_dim, z_dim))
patch_x = x_dim // patch_split_size[0] 
patch_y = y_dim // patch_split_size[1]
patch_z = z_dim // patch_split_size[2]
print("patch_x:%i, patch_y:%i, patch_z:%i" % (patch_x, patch_y, patch_z))

num_patches = (x_dim // patch_x) * (y_dim // patch_y) * (z_dim // patch_z)

patch_dim = patch_x * patch_y * patch_z * data_dim

patch_reshape = Rearrange('b (x p1) (y p2) (z p3) data -> b (x y z) (p1 p2 p3 data)', p1=patch_x, p2=patch_y, p3=patch_z)  
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(patch_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape

num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

num_spec = 7
spec = torch.randn(1, num_spec)
print("spec:", spec.shape)


spec = torch.unsqueeze(spec,0)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,0)
print("spec:", spec.shape)


spec_L1 = nn.Linear(1,dim)
spec1 = spec_L1(spec[:,:,:,0])
print("spec1_Linear:", spec1.shape)

x = torch.cat((x, spec1), dim=1)
print("cat(x, spec_1):", x.shape)


spec_L2 = nn.Linear(1,dim)
spec2 = spec_L2(spec[:,:,:,1])
print("spec2_Linear:", spec.shape)

x = torch.cat((x, spec2), dim=1)
print("cat(x, spec_2):", x.shape)


spec_L3 = nn.Linear(1,dim)
spec3 = spec_L3(spec[:,:,:,2])
print("spec3_Linear:", spec.shape)

x = torch.cat((x, spec3), dim=1)
print("cat(x, spec_3):", x.shape)


spec_L4 = nn.Linear(1,dim)
spec4 = spec_L4(spec[:,:,:,3])
print("spec4_Linear:", spec.shape)

x = torch.cat((x, spec4), dim=1)
print("cat(x, spec_4):", x.shape)


spec_L5 = nn.Linear(1,dim)
spec5 = spec_L5(spec[:,:,:,4])
print("spec5_Linear:", spec.shape)

x = torch.cat((x, spec5), dim=1)
print("cat(x, spec_5):", x.shape)


spec_L6 = nn.Linear(1,dim)
spec6 = spec_L6(spec[:,:,:,5])
print("spec_Linear:", spec.shape)

x = torch.cat((x, spec6), dim=1)
print("cat(x, spec_6):", x.shape)


spec_L7 = nn.Linear(1,dim)
spec7 = spec_L7(spec[:,:,:,6])
print("spec_Linear:", spec.shape)

x = torch.cat((x, spec7), dim=1)
print("cat(x, spec_7):", x.shape)


pos_emb = nn.Parameter(torch.randn(1, num_patches + num_pred + num_spec, dim))
x += pos_emb[:,:(n+num_pred+num_spec)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

#dim_head = 512
#heads = num_patches
dim_head = num_patches
heads = 8
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_ff", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

generator = nn.Linear(dim, num_pred)
pred = generator(x)
print("model_out:", pred.shape)



inp: torch.Size([1, 40, 40, 20, 12])
x_dim:40, y_dim:40, z_dim:20
patch_x:40, patch_y:40, patch_z:1
patch_reshape: torch.Size([1, 20, 19200])
patch_Linear: torch.Size([1, 20, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 121, 512])
spec: torch.Size([1, 7])
spec: torch.Size([1, 1, 7])
spec: torch.Size([1, 1, 1, 7])
spec1_Linear: torch.Size([1, 1, 512])
cat(x, spec_1): torch.Size([1, 122, 512])
spec2_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_2): torch.Size([1, 123, 512])
spec3_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_3): torch.Size([1, 124, 512])
spec4_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_4): torch.Size([1, 125, 512])
spec5_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_5): torch.Size([1, 126, 512])
spec_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_6): torch.Size([1, 127, 512])
spec_Linear: torch.Size([1, 1, 1, 7])
cat(x, spec_7): torch.Size([1, 128, 512])
pos_emb: torch.Size([1, 128, 512])
pos_dropout: 

In [2]:
# パッチ分割 spec 完成版

patch_split_size = [1,1,20]

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

_, x_dim, y_dim, z_dim, data_dim = inp.shape
print("x_dim:%i, y_dim:%i, z_dim:%i" % (x_dim, y_dim, z_dim))
patch_x = x_dim // patch_split_size[0] 
patch_y = y_dim // patch_split_size[1]
patch_z = z_dim // patch_split_size[2]
print("patch_x:%i, patch_y:%i, patch_z:%i" % (patch_x, patch_y, patch_z))

num_patches = (x_dim // patch_x) * (y_dim // patch_y) * (z_dim // patch_z)

patch_dim = patch_x * patch_y * patch_z * data_dim

patch_reshape = Rearrange('b (x p1) (y p2) (z p3) data -> b (x y z) (p1 p2 p3 data)', p1=patch_x, p2=patch_y, p3=patch_z)  
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(patch_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape

num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

num_spec = 7
spec = torch.randn(1, num_spec)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

def clones(module, n):
    """
    Produce N identical layers.
    """
    assert isinstance(module, nn.Module)
    return nn.ModuleList([deepcopy(module) for _ in range(n)])

spec_Ls = clones(nn.Linear(1,dim), num_spec)


for i, spec_L in enumerate(spec_Ls):
    if(i == 0):
        spec_emb_all = spec_L(spec[:,:,:,i])
        print(f'spec{i}_Linear:{spec_emb_all.shape}')
    else:
        spec_emb = spec_L(spec[:,:,:,i])
        spec_emb_all = torch.cat((spec_emb_all, spec_emb), dim=1)
        print(f'spec{i}_Linear:{spec_emb_all.shape}')


x = torch.cat((x, spec_emb_all), dim=1)
print("cat(x, spec_emb):", x.shape)


pos_emb = nn.Parameter(torch.randn(1, num_patches + num_pred + num_spec, dim))
x += pos_emb[:,:(n+num_pred+num_spec)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

#dim_head = 512
#heads = num_patches
dim_head = num_patches
heads = 8
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_ff", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

generator = nn.Linear(dim, num_pred)
pred = generator(x)
print("model_out:", pred.shape)



inp: torch.Size([1, 40, 40, 20, 12])
x_dim:40, y_dim:40, z_dim:20
patch_x:40, patch_y:40, patch_z:1
patch_reshape: torch.Size([1, 20, 19200])
patch_Linear: torch.Size([1, 20, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 121, 512])
spec: torch.Size([1, 7])
spec: torch.Size([1, 1, 7])
spec: torch.Size([1, 1, 1, 7])
spec0_Linear:torch.Size([1, 1, 512])
spec1_Linear:torch.Size([1, 2, 512])
spec2_Linear:torch.Size([1, 3, 512])
spec3_Linear:torch.Size([1, 4, 512])
spec4_Linear:torch.Size([1, 5, 512])
spec5_Linear:torch.Size([1, 6, 512])
spec6_Linear:torch.Size([1, 7, 512])
cat(x, spec_emb): torch.Size([1, 128, 512])
pos_emb: torch.Size([1, 128, 512])
pos_dropout: torch.Size([1, 128, 512])
PreNorm_Attn torch.Size([1, 128, 512])
qkv: torch.Size([1, 128, 480])
qkv[0]: torch.Size([1, 128, 160])
qkv[1]: torch.Size([1, 128, 160])
qkv[2]: torch.Size([1, 128, 160])
q: torch.Size([1, 8, 128, 20])
k: torch.Size([1, 8, 128, 20])
v

In [4]:
# パッチ分割 spec 完成版2

patch_split_size = [1,1,20]

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

_, x_dim, y_dim, z_dim, data_dim = inp.shape
print("x_dim:%i, y_dim:%i, z_dim:%i" % (x_dim, y_dim, z_dim))
patch_x = x_dim // patch_split_size[0] 
patch_y = y_dim // patch_split_size[1]
patch_z = z_dim // patch_split_size[2]
print("patch_x:%i, patch_y:%i, patch_z:%i" % (patch_x, patch_y, patch_z))

num_patches = (x_dim // patch_x) * (y_dim // patch_y) * (z_dim // patch_z)

patch_dim = patch_x * patch_y * patch_z * data_dim

patch_reshape = Rearrange('b (x p1) (y p2) (z p3) data -> b (x y z) (p1 p2 p3 data)', p1=patch_x, p2=patch_y, p3=patch_z)  
x = patch_reshape(inp)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(patch_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape

num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, num_patches + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

#dim_head = 512
#heads = num_patches
dim_head = num_patches
heads = 8
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_ff", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

num_spec = 7
spec = torch.randn(1, num_spec)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

def clones(module, n):
    """
    Produce N identical layers.
    """
    assert isinstance(module, nn.Module)
    return nn.ModuleList([deepcopy(module) for _ in range(n)])

spec_Ls = clones(nn.Linear(1,dim), num_spec)


for i, spec_L in enumerate(spec_Ls):
    if(i == 0):
        spec_emb_all = spec_L(spec[:,:,:,i])
        print(f'spec{i}_Linear:{spec_emb_all.shape}')
    else:
        spec_emb = spec_L(spec[:,:,:,i])
        spec_emb_all = torch.cat((spec_emb_all, spec_emb), dim=1)
        print(f'spec{i}_Linear:{spec_emb_all.shape}')


x = torch.cat((x, spec_emb_all), dim=1)
print("cat(x, spec_emb):", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

generator = nn.Linear(dim, num_pred)
pred = generator(x)
print("model_out:", pred.shape)



inp: torch.Size([2, 40, 40, 20, 12])
x_dim:40, y_dim:40, z_dim:20
patch_x:40, patch_y:40, patch_z:1
patch_reshape: torch.Size([2, 20, 19200])
patch_Linear: torch.Size([2, 20, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([2, 101, 512])
cat(pred_tokens,x): torch.Size([2, 121, 512])
pos_emb: torch.Size([2, 121, 512])
pos_dropout: torch.Size([2, 121, 512])
PreNorm_Attn torch.Size([2, 121, 512])
qkv: torch.Size([2, 121, 480])
qkv[0]: torch.Size([2, 121, 160])
qkv[1]: torch.Size([2, 121, 160])
qkv[2]: torch.Size([2, 121, 160])
q: torch.Size([2, 8, 121, 20])
k: torch.Size([2, 8, 121, 20])
v: torch.Size([2, 8, 121, 20])
dots: torch.Size([2, 8, 121, 121])
attn: torch.Size([2, 8, 121, 121])
attn_score: torch.Size([2, 8, 121, 20])
attn_reshape torch.Size([2, 121, 160])
attn_Linear: torch.Size([2, 121, 512])
attn_out: torch.Size([2, 121, 512])
x_attn_out: torch.Size([2, 121, 512])
PreNorm_ff torch.Size([2, 121, 512])
ff_Linear1: torch.Size([2, 121, 2048])
ff_GELU: torch.Size

In [4]:
#convolution

patch_split_size = [1,1,20]

inp = torch.rand(1,40,40,20,12)
print("inp:", inp.shape)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2),
        )

    def forward(self, x):
        return self.conv_block(x)

x = inp.permute(0, 4, 1, 2, 3)
print("permute:", x.shape)
    
conv1 = ConvBlock(12, 12)
x = conv1(x)
print("conv1:", x.shape)

conv2 = ConvBlock(12, 12)
x = conv2(x)
print("conv2:", x.shape)

#conv3 = ConvBlock(12, 12)
#x = conv3(x)
#print("conv3:", x.shape)

x = x.permute(0, 2, 3, 4, 1)
print("permute:", x.shape)

_, x_dim, y_dim, z_dim, data_dim = x.shape

num_patches = x_dim * y_dim * z_dim

patch_reshape = Rearrange('b x y z data -> b (x y z) data')  
x = patch_reshape(x)
print("patch_reshape:", x.shape)

dim = 512

patch_L = nn.Linear(data_dim, dim)
x = patch_L(x)
print("patch_Linear:", x.shape)

b, n, _ = x.shape

num_pred = 101

pred_token = nn.Parameter(torch.randn(1, num_pred, dim))
print("pred_token:", pred_token.shape)

pred_tokens = repeat(pred_token, '() n d -> b n d', b = b)
print("pred_tokens:", pred_tokens.shape)

x = torch.cat((pred_tokens, x), dim=1)
print("cat(pred_tokens,x):", x.shape)

pos_emb = nn.Parameter(torch.randn(1, num_patches + num_pred, dim))
x += pos_emb[:,:(n+num_pred)]
print("pos_emb:", x.shape)

dropout_rate = 0.1
pos_dropout = nn.Dropout(dropout_rate)
x = pos_dropout(x)
print("pos_dropout:", x.shape)

attn_norm = nn.LayerNorm(dim)
attn_in = attn_norm(x)
print("PreNorm_Attn", attn_in.shape)

#dim_head = 512
#heads = num_patches
dim_head = num_patches
heads = 8
inner_dim = dim_head * heads

scale = dim_head ** -0.5

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print("qkv:", to_qkv(x).shape)
qkv = to_qkv(x).chunk(3, dim = -1)
print("qkv[0]:", qkv[0].shape)
print("qkv[1]:", qkv[1].shape)
print("qkv[2]:", qkv[2].shape)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

dots = torch.matmul(q, k.transpose(-1, -2)) * scale
print("dots:", dots.shape)

attend = nn.Softmax(dim = -1)
attn = attend(dots)
print("attn:", attn.shape)

attn_score = torch.matmul(attn, v)
print("attn_score:", attn_score.shape)

attn_score = rearrange(attn_score, 'b h n d -> b n (h d)')
print("attn_reshape", attn_score.shape)

attn_L = nn.Linear(inner_dim, dim)
attn_out = attn_L(attn_score)
print("attn_Linear:", attn_out.shape)

attn_dropout = nn.Dropout(dropout_rate)
attn_out = attn_dropout(attn_out)
print("attn_out:", attn_out.shape)

x = attn_out + x
print("x_attn_out:", x.shape)

ff_norm = nn.LayerNorm(dim)
ff_in = ff_norm(x)
print("PreNorm_ff", ff_in.shape)

hidden_dim = 2048
ff_L1 = nn.Linear(dim, hidden_dim)
ff = ff_L1(ff_in)
print("ff_Linear1:", ff.shape)

GELU = nn.GELU()
ff = GELU(ff)
print("ff_GELU:", ff.shape)

ff_dropout1 = nn.Dropout(dropout_rate)
ff = ff_dropout1(ff)
print("ff_dropout1", ff.shape)

ff_L2 = nn.Linear(hidden_dim, dim)
ff = ff_L2(ff)
print("ff_Linear2:", ff.shape)

ff_dropout2 = nn.Dropout(dropout_rate)
ff = ff_dropout2(ff)
print("ff_dropout2", ff.shape)

x = ff + x
print("x_ff_out:", x.shape)

num_spec = 7
spec = torch.randn(1, num_spec)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

spec = torch.unsqueeze(spec,1)
print("spec:", spec.shape)

def clones(module, n):
    """
    Produce N identical layers.
    """
    assert isinstance(module, nn.Module)
    return nn.ModuleList([deepcopy(module) for _ in range(n)])

spec_Ls = clones(nn.Linear(1,dim), num_spec)


for i, spec_L in enumerate(spec_Ls):
    if(i == 0):
        spec_emb_all = spec_L(spec[:,:,:,i])
        print(f'spec{i}_Linear:{spec_emb_all.shape}')
    else:
        spec_emb = spec_L(spec[:,:,:,i])
        spec_emb_all = torch.cat((spec_emb_all, spec_emb), dim=1)
        print(f'spec{i}_Linear:{spec_emb_all.shape}')


x = torch.cat((x, spec_emb_all), dim=1)
print("cat(x, spec_emb):", x.shape)

x = x.mean(dim=1)
print("pooling:", x.shape)

to_latent = nn.Identity()
x = to_latent(x)
print("latent:", x.shape)

generator_norm = nn.LayerNorm(dim)
x = generator_norm(x)
print("generator_norm", x.shape)

generator = nn.Linear(dim, num_pred)
pred = generator(x)
print("model_out:", pred.shape)



inp: torch.Size([1, 40, 40, 20, 12])
permute: torch.Size([1, 12, 40, 40, 20])
conv1: torch.Size([1, 12, 20, 20, 10])
conv2: torch.Size([1, 12, 10, 10, 5])
permute: torch.Size([1, 10, 10, 5, 12])
patch_reshape: torch.Size([1, 500, 12])
patch_Linear: torch.Size([1, 500, 512])
pred_token: torch.Size([1, 101, 512])
pred_tokens: torch.Size([1, 101, 512])
cat(pred_tokens,x): torch.Size([1, 601, 512])
pos_emb: torch.Size([1, 601, 512])
pos_dropout: torch.Size([1, 601, 512])
PreNorm_Attn torch.Size([1, 601, 512])
qkv: torch.Size([1, 601, 12000])
qkv[0]: torch.Size([1, 601, 4000])
qkv[1]: torch.Size([1, 601, 4000])
qkv[2]: torch.Size([1, 601, 4000])
q: torch.Size([1, 8, 601, 500])
k: torch.Size([1, 8, 601, 500])
v: torch.Size([1, 8, 601, 500])
dots: torch.Size([1, 8, 601, 601])
attn: torch.Size([1, 8, 601, 601])
attn_score: torch.Size([1, 8, 601, 500])
attn_reshape torch.Size([1, 601, 4000])
attn_Linear: torch.Size([1, 601, 512])
attn_out: torch.Size([1, 601, 512])
x_attn_out: torch.Size([1, 60