In [6]:
import torch
import numpy as np
import torch
import torchvision.models as models
from torch import nn
import numpy as np
import pandas as pd

In [27]:
data = {
    'EmpCode': ['E1', 'E2', 'E3', 'E4', 'E5'],
    'Gender': ['Male', 'Female', 'Female', 'Male', 'Male'],
    'Age': [27, 24, 29, 24, 25],
    'Department': ['Accounting', 'Sales', 'Accounting', np.nan, 'Sales'],
    'Class': [2,3,4,-1,-1]

}
df = pd.DataFrame(data)
label_col = "Class"
all_classes = pd.DataFrame(df[label_col].value_counts().keys()).iloc[:, 0]
all_classes
# all_classes, {k: v for v, k in enumerate(all_classes)}

0   -1
1    2
2    3
3    4
Name: 0, dtype: int64

In [8]:
visited = [ [ False for _ in range(4)] for _ in range(3) ]
visited

[[False, False, False, False],
 [False, False, False, False],
 [False, False, False, False]]

In [149]:

class HierarchicalSoftmax(nn.Module):

    def __init__(self, ntokens, nhid, ntokens_per_class=None, **kwargs):
        super(HierarchicalSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nhid = nhid

        if ntokens_per_class is None:
            ntokens_per_class = int(np.ceil(np.sqrt(ntokens)))

        self.ntokens_per_class = ntokens_per_class

        self.nclasses = int(np.ceil(self.ntokens * 1. / self.ntokens_per_class))
        self.ntokens_actual = self.nclasses * self.ntokens_per_class

        self.layer_top_W = nn.Parameter(torch.FloatTensor(self.nhid, self.nclasses), requires_grad=True)
        # print(self.layer_top_W.shape)
        self.layer_top_b = nn.Parameter(torch.FloatTensor(self.nclasses), requires_grad=True)

        self.layer_bottom_W = nn.Parameter(torch.FloatTensor(self.nclasses, self.nhid, self.ntokens_per_class),
                                           requires_grad=True)
        # print(self.layer_bottom_W.shape)
        self.layer_bottom_b = nn.Parameter(torch.FloatTensor(self.nclasses, self.ntokens_per_class), requires_grad=True)

        self.softmax = nn.Softmax(dim=1)

        self.init_weights()



    def init_weights(self):
        initrange = 0.1
        self.layer_top_W.data.uniform_(-initrange, initrange)
        self.layer_top_b.data.fill_(0)
        self.layer_bottom_W.data.uniform_(-initrange, initrange)
        self.layer_bottom_b.data.fill_(0)
        
    def _predict(self, inputs):
        batch_size, d = inputs.size()

        layer_top_logits = torch.matmul(inputs, self.layer_top_W) + self.layer_top_b
        layer_top_probs = self.softmax(layer_top_logits)
        
        label_position_top = torch.argmax(layer_top_probs, dim=1)
                
        layer_bottom_logits = torch.squeeze(
            torch.bmm(torch.unsqueeze(inputs, dim=1), self.layer_bottom_W[label_position_top]), dim=1) + \
                              self.layer_bottom_b[label_position_top]
        layer_bottom_probs = self.softmax(layer_bottom_logits)
        
        return torch.bmm(layer_top_probs.unsqueeze(2), layer_bottom_probs.unsqueeze(1)).flatten(start_dim=1)
        

    def forward(self, inputs, labels=None):
        if labels is None:
            return self._predict(inputs)
        batch_size, d = inputs.size()

        layer_top_logits = torch.matmul(inputs, self.layer_top_W) + self.layer_top_b
        layer_top_probs = self.softmax(layer_top_logits)
        
        label_position_top = (labels / self.ntokens_per_class).long()
        label_position_bottom = (labels % self.ntokens_per_class).long()
        
        # print(layer_top_probs.shape, label_position_top.shape)

        layer_bottom_logits = torch.squeeze(
            torch.bmm(torch.unsqueeze(inputs, dim=1), self.layer_bottom_W[label_position_top]), dim=1) + \
                              self.layer_bottom_b[label_position_top]
        layer_bottom_probs = self.softmax(layer_bottom_logits)

        target_probs = layer_top_probs[torch.arange(batch_size).long(), label_position_top] * layer_bottom_probs[
            torch.arange(batch_size).long(), label_position_bottom]

        # print(f"top {layer_top_probs.shape} {layer_top_probs}")
        # print(f"bottom {layer_bottom_probs.shape} {layer_bottom_probs}")
        top_indx = torch.argmax(layer_top_probs, dim=1)
        botton_indx = torch.argmax(layer_bottom_probs, dim=1)

        real_indx = (top_indx * self.ntokens_per_class) + botton_indx
        # print(top_indx, self.nclasses, botton_indx)
        # print(f"target {target_probs.shape} {target_probs}")

        # loss = -torch.mean(torch.log(target_probs.type(torch.float32) + 1e-3))
        loss = -torch.mean(torch.log(target_probs))
        with torch.no_grad():
            preds = torch.bmm(layer_top_probs.unsqueeze(2), layer_bottom_probs.unsqueeze(1)).flatten(start_dim=1)

        return loss, target_probs, layer_top_probs, layer_bottom_probs, top_indx, botton_indx, real_indx, preds



In [155]:
class ObservationTransformer(nn.Module):
    
    def __init__(self, encoder : nn.Module, nhead=10, feature_dim = 1000, output_size=128):
        super().__init__()
        self.encoder = encoder
        self.head = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=nhead, batch_first=True)
        self.decoder = nn.Linear(feature_dim, output_size)
        self.feature_dim = feature_dim
        
    def forward(self, x):
        B, S, C, R, _ = x.shape 
        x = x.reshape(-1, C, R, R) # squeeze observations to (batch * sequence, channel, witdh, height) 
                                    # size into a single dim for encoding as batch
        features = self.encoder(x) # encode all images for B observations of S images each
        features = features.reshape(B, S, self.feature_dim) # resize back to observation sizes (batch, sequence, )
        out = head(features) #apply attention to teh features
        return self.decoder(out.mean(1))

    
class HObservationTransformer(nn.Module):
    
    def __init__(self, encoder : nn.Module, nhead=10, feature_dim = 1000, classes=80000, ntokens_per_class=80):
        super().__init__()
        self.obs_transformer = ObservationTransformer(encoder=encoder, nhead=nhead, feature_dim=feature_dim, output_size=128)
        self.hs =  HierarchicalSoftmax(ntokens=classes, nhid=128, ntokens_per_class=ntokens_per_class)
        # print(self.hs.nclasses, self.hs.ntokens_per_class)
        
    def forward(self, x, y=None):
        assert len(x.shape) == 5, "Observation shape should be (batch, sequence, channels, witdh, height)"
        x = self.obs_transformer(x) 
        return self.hs(x, y)

In [156]:
encoder = models.resnet50(pretrained=True)
m = ObservationTransformer(encoder=encoder, nhead=10, feature_dim=1000, output_size=128)
hm = HObservationTransformer(encoder=encoder, nhead=10, feature_dim=1000, classes=80000, ntokens_per_class=80)

In [157]:
x = torch.rand(2, 4, 3, 224, 224)
y = m(x)
y.shape

torch.Size([2, 128])

In [158]:
x = torch.rand(2, 4, 3, 224, 224)
labels = torch.ones(2).long()
loss, target_probs, layer_top_probs, layer_bottom_probs, top_indx, botton_indx, real_indx, preds = hm(x, labels)
preds.shape

torch.Size([2, 80000])

In [159]:
x = torch.rand(2, 4, 3, 224, 224)
preds = hm(x)
preds.shape

torch.Size([2, 80000])

In [31]:
B, S, D = features.shape
out = head(features)
out = decoder(out.mean(0))
out.shape

torch.Size([4, 1000])

In [18]:
layer_top_probs = torch.tensor([[1, 2], [4, 5]]) 
layer_bottom_probs = torch.tensor([[10, 20, 30], [100, 200, 300]])

# layer_top_probs = torch.rand(2, 1000)
# layer_bottom_probs = torch.rand(2, 80)

preds = torch.bmm(layer_top_probs.unsqueeze(2), layer_bottom_probs.unsqueeze(1)).flatten(start_dim=1)
preds.shape, preds

(torch.Size([2, 6]),
 tensor([[  10,   20,   30,   20,   40,   60],
         [ 400,  800, 1200,  500, 1000, 1500]]))

In [2]:
import torch
import numpy as np

def get_mrr(predictions_dict):
    ranks = np.asarray([get_rank(value) for key, value in predictions_dict.items()])
    return np.sum((1 / ranks)) / len(predictions_dict)

def get_rank(dict_value):
    prob = dict_value['prob']
    label = dict_value['label']
    idx = np.argsort(prob)[::-1]
    np.argmax(prob) == label
    rank_i = np.squeeze(np.where(idx == label)) + 1

    return rank_i

In [22]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# class ViT(nn.Module):
#     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
#         super().__init__()
#         image_height, image_width = pair(image_size)
#         patch_height, patch_width = pair(patch_size)

#         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

#         num_patches = (image_height // patch_height) * (image_width // patch_width)
#         patch_dim = channels * patch_height * patch_width
#         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

#         self.to_patch_embedding = nn.Sequential(
#             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
#             nn.Linear(patch_dim, dim),
#         )

#         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
#         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
#         self.dropout = nn.Dropout(emb_dropout)

#         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

#         self.pool = pool
#         self.to_latent = nn.Identity()

#         self.mlp_head = nn.Sequential(
#             nn.LayerNorm(dim),
#             nn.Linear(dim, num_classes)
#         )

#     def forward(self, img):
#         x = self.to_patch_embedding(img)
#         b, n, _ = x.shape

#         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
#         x = torch.cat((cls_tokens, x), dim=1)
#         x += self.pos_embedding[:, :(n + 1)]
#         x = self.dropout(x)

#         x = self.transformer(x)

#         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

#         x = self.to_latent(x)
#         return self.mlp_head(x)
    
    
class ViTEncoder(nn.Module):
    
    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., flatten=True):
        super().__init__()
        
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.depth = depth
        self.flatten = flatten
        
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()


    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        if self.flatten:
            return torch.flatten(x, start_dim=1)
        return x
    

class ModularViT(nn.Module):
    def __init__(self, *, encoder, num_classes, mlp_dim, pool = 'cls', channels = 3, dropout = 0):
        super().__init__()

        self.encoder = encoder
        self.encoder.flatten = False
        self.decoder = Transformer(self.encoder.dim, self.encoder.depth, self.encoder.heads, self.encoder.dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.encoder.dim),
            nn.Linear(self.encoder.dim, num_classes)
        )
        
    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False
    
    

    def forward(self, img):
        x = self.encoder(img)
        
        x = self.decoder(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


In [32]:
enc = ViTEncoder(
    image_size = 224,
    patch_size = 32,
    dim = 128,
    depth = 3,
    heads = 16,
    dropout = 0.1,
    emb_dropout = 0.1,
    mlp_dim = 2048,
    flatten = True
)
# model = ModularViT(encoder=enc, num_classes=80000, pool = 'cls', channels = 3, dropout = 0.1, mlp_dim=2048)
# model.freeze_encoder()

img = torch.randn(4, 3, 224, 224)

enc_preds = enc(img)
# preds = model(img)

In [33]:
preds.shape, enc_preds.shape

(torch.Size([4, 80000]), torch.Size([4, 6400]))

In [34]:
r = torch.rand((4, 10, 32))

In [35]:
r.shape

torch.Size([4, 10, 32])

In [36]:
r[:, 0].shape

torch.Size([4, 32])