In [None]:
import os
import math
import torch.nn as nn
import torch
import inspect
from torchinfo import summary

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import transforms 

In [None]:
from dataclasses import dataclass
from torch.nn import functional as F
from PIL import Image
from torchvision.datasets import VOCDetection

In [None]:
import numpy as np
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
from TRUST_preprocessing import split_n_rotate
import json

In [None]:
# Use the currently selected CUDA device when one is available, otherwise the CPU.
# `device` is always a torch.device (never a bare string), so later code can rely
# on attributes such as `device.type` regardless of the hardware.
device = (
    torch.device('cuda', torch.cuda.current_device())
    if torch.cuda.is_available()
    else torch.device('cpu')
)
# Tensors and modules created without an explicit device will be placed on `device`.
torch.set_default_device(device)

In [None]:
@dataclass
class TRUSTConfig:
    """Hyper-parameters read by the dataset and the model classes in this notebook.

    New fields are appended after the original ones so positional construction
    keeps working.
    """
    N: int = 100             # number of row queries (size of the row embedding table)
    M: int = 100             # number of column queries (size of the column embedding table)
    d: int = 8               # embedding width used throughout the model
    batch_size: int = 64
    num_epochs: int = 100
    num_decoders: int = 6    # DecoderBlocks per decoder stack
    n_heads: int = 4         # attention heads; `d` must be divisible by this
    block_size: int = 8      # side length of the causal mask built in SelfAttention
    bias: bool = False       # whether the Linear layers carry a bias term
    dropout: float = 0.0     # read by the attention/MLP layers; 0.0 disables dropout

    @property
    def n_embd(self) -> int:
        """Alias of ``d``: some layers read the embedding width as ``config.n_embd``."""
        return self.d


config = TRUSTConfig()

<img src="resnet_with_fpn.png" alt="ResNet backbone with a feature pyramid network (FPN)" />

In [None]:
# Scratch inspection: wrap the first eight child modules of a torchvision ResNet-18
# in an nn.Sequential and display the resulting module tree.
model = nn.Sequential(*list(resnet18().children())[:8])
model

In [None]:
# Scratch inspection: rebind `model` to a complete ResNet-18 (constructed without a
# weights argument) and display its module tree.
model = resnet18()
model

In [None]:
class TRUSTDataset(Dataset):
    """Table images paired with row/column separator annotations.

    Every image listed in the JSON annotation file is decoded once in
    ``__init__`` and kept in memory. ``__getitem__`` draws a random rotation
    angle and delegates to ``split_n_rotate`` to build the sample.
    """

    def __init__(self, config, root, anno_dir, img_dir):
        """Load the annotation file and decode every annotated image.

        Args:
            config: object exposing ``N`` and ``M`` (forwarded to ``split_n_rotate``).
            root: prefix prepended, by plain string concatenation, to
                ``anno_dir`` and ``img_dir``.
            anno_dir: location of the JSON annotation file, relative to ``root``.
            img_dir: location of the image directory, relative to ``root``.

        Raises:
            FileNotFoundError: an annotated image does not exist on disk.
            ValueError: an image file exists but OpenCV cannot decode it.
        """
        super().__init__()
        self.N = config.N
        self.M = config.M
        self.anno_dir = root+anno_dir
        self.img_dir = root+img_dir
        with open(self.anno_dir) as f:
            self.anno_list = json.load(f)
        self.img_list = [self._load_image(img_anno['filename']) for img_anno in self.anno_list]
        # Pre-drawn angles per image. __getitem__ does not read them (it samples a
        # fresh angle on every call); the attribute is kept for existing users.
        self.angles = [np.random.choice(np.arange(-45, 46), size=1000) for _ in range(len(self.anno_list))]

    def _load_image(self, filename):
        """Decode one annotated image, preferring the copy under ``self.img_dir``."""
        # os.path.join leaves an absolute `filename` untouched, and the fallback keeps
        # annotations whose paths are already resolvable on their own working.
        candidate = os.path.join(self.img_dir, filename)
        path = candidate if os.path.isfile(candidate) else filename
        if not os.path.isfile(path):
            raise FileNotFoundError(f"image not found: tried {candidate!r} and {filename!r}")
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise ValueError(f"could not decode image: {path!r}")
        return img

    def __len__(self):
        """Number of annotated images."""
        return len(self.anno_list)

    def __getitem__(self, index):
        """Return ``(rotated_img, trust_anno)`` as produced by ``split_n_rotate``.

        The rotation angle is an integer drawn uniformly from [-45, 45] on every call.
        """
        img_anno = self.anno_list[index]
        img = self.img_list[index]

        angle = np.random.randint(-45, 46)
        rotated_img, trust_anno = split_n_rotate(img, img_anno['col_separators'], img_anno['row_separators'], self.N, self.M, angle)

        return rotated_img, trust_anno

In [None]:
class ResNet_with_FPN(nn.Module):
    """ResNet-18 feature extractor followed by a top-down feature pyramid.

    Four intermediate ResNet activations are each reduced to one channel by a
    1x1 convolution, merged top-down (the coarser level is upsampled and added
    to the finer one), then expanded to ``config.d`` channels by a 3x3
    stride-2 convolution.
    """

    def __init__(self, config):
        super().__init__()
        self.d = config.d
        self.resnet = resnet18()
        self.pyramid = []
        # Graph node to tap -> key under which the extractor returns its activation.
        self.return_nodes = {
            'layer1.1.conv2': 'layer1', #P2
            'layer2.1.conv2': 'layer2', #P3
            'layer3.1.conv2': 'layer3', #P4
            'layer4.1.conv2': 'layer4'  #P5
        }
        self.feature_extractor = create_feature_extractor(self.resnet, return_nodes=self.return_nodes)
        # Lateral 1x1 convolutions, one per tapped level (input widths 64/128/256/512).
        self.conv1_list = nn.ModuleList([
            nn.Conv2d(64, 1, (1,1), 1),
            nn.Conv2d(128, 1, (1,1), 1),
            nn.Conv2d(256, 1, (1,1), 1),
            nn.Conv2d(512, 1, (1,1), 1),
        ])
        # Output 3x3 convolutions (stride 2, no padding), one per level.
        self.conv3_list = nn.ModuleList([nn.Conv2d(1, self.d, (3,3), 2) for _ in range(4)])

    def forward(self, x):
        """Compute the pyramid outputs for an image batch.

        Args:
            x: input batch handed to the ResNet feature extractor.

        Returns:
            list of four tensors ordered fine to coarse; index 0 is the level
            tagged P2 in ``return_nodes`` and index 3 the level tagged P5.
        """
        intermediate_outputs = self.feature_extractor(x)
        num_levels = len(self.conv1_list)
        # Pre-sized so each level can be stored by index while iterating coarse -> fine.
        pyramid = [None] * num_levels
        outs = [None] * num_levels
        for i in range(num_levels - 1, -1, -1):
            # Extractor keys are 1-based ('layer1'..'layer4'); list indices are 0-based.
            layer_name = 'layer' + str(i + 1)
            merged = self.conv1_list[i](intermediate_outputs[layer_name])
            if i < num_levels - 1:
                # Upsample to this level's exact size rather than by a fixed factor of 2,
                # so odd feature-map sizes still line up. Out-of-place add keeps autograd safe.
                merged = merged + F.interpolate(pyramid[i + 1], size=merged.shape[-2:])
            pyramid[i] = merged
            outs[i] = self.conv3_list[i](merged)
        return outs
        

            


        

        
        

<img src="TRUST.jpg" alt="TRUST model architecture diagram" />

In [None]:
# Scratch tensor of shape (3,).
a = torch.Tensor([1,2,3])

In [None]:
# Demo of the unsqueeze + repeat tiling pattern: `a` has shape (1, 2, 3); a new axis is
# inserted at dim 2 and tiled 5 times, so `b` has shape (1, 2, 5, 3).
a = torch.Tensor([[[1,2,3],[1,2,3]]])
b = a.unsqueeze(2).repeat(1,1,5,1)
b

In [None]:
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``encoder_output``."""

    def __init__(self, config):
        """
        Args:
            config: object exposing ``d`` (embedding width), ``n_heads``,
                ``block_size`` and ``bias``; ``dropout`` is optional and
                defaults to 0.0 when the attribute is absent.

        Raises:
            ValueError: if ``d`` is not divisible by ``n_heads``.
        """
        super().__init__()
        self.n_embd = config.d
        self.n_heads = config.n_heads
        if self.n_embd % self.n_heads != 0:
            raise ValueError(f"d ({self.n_embd}) must be divisible by n_heads ({self.n_heads})")
        self.block_size = config.block_size
        self.c_attn_q = nn.Linear(self.n_embd, self.n_embd, bias=config.bias)
        self.c_attn_kv = nn.Linear(self.n_embd, self.n_embd*2, bias=config.bias)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=config.bias)
        # Tolerate configs that do not define `dropout`.
        dropout = getattr(config, 'dropout', 0.0)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, pos_encoding, encoder_output):
        """Attend from the queries ``x`` to ``encoder_output``.

        Args:
            x: query features, shape (B, N, C).
            pos_encoding: positional term for the queries; it is added to the
                projected queries, so it must broadcast against (B, N, C).
            encoder_output: memory features, shape (B, T, C).

        Returns:
            Tensor of shape (B, N, C).
        """
        assert x.shape[0] == encoder_output.shape[0]
        B, T, C = encoder_output.shape #batch_size, memory length, n_embd
        N = x.shape[1] #n_queries
        head_size = C//self.n_heads
        k, v = self.c_attn_kv(encoder_output).split(self.n_embd, dim=2)
        # pos_encoding is query-shaped (B, N, C), so it belongs on q; adding it to the
        # (B, T, C) keys only works in the accidental case T == N.
        q = self.c_attn_q(x) + pos_encoding

        q = q.view(B, N, self.n_heads, head_size).transpose(1,2) # (B, n_heads, N, h_size)
        k = k.view(B, T, self.n_heads, head_size).transpose(1,2) # (B, n_heads, T, h_size)
        v = v.view(B, T, self.n_heads, head_size).transpose(1,2) # (B, n_heads, T, h_size)

        att = q@k.transpose(-2,-1)*(1/math.sqrt(k.size(-1))) # (B, n_heads, N, T)
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att@v # (B, n_heads, N, h_size)
        y = y.transpose(1,2).contiguous().view(B,N,C) # heads merged back: (B, N, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional causal (lower-triangular) mask."""

    def __init__(self, config, masked=False):
        """
        Args:
            config: object exposing ``d`` (embedding width), ``n_heads``,
                ``block_size`` and ``bias``; ``dropout`` is optional and
                defaults to 0.0 when the attribute is absent.
            masked: when True, position i may only attend to positions <= i.

        Raises:
            ValueError: if ``d`` is not divisible by ``n_heads``.
        """
        super().__init__()
        self.n_embd = config.d
        self.n_heads = config.n_heads
        if self.n_embd % self.n_heads != 0:
            raise ValueError(f"d ({self.n_embd}) must be divisible by n_heads ({self.n_heads})")
        self.block_size = config.block_size
        self.c_attn = nn.Linear(self.n_embd, self.n_embd*3, bias=config.bias)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=config.bias)
        causal = torch.tril(torch.ones((self.block_size, self.block_size))).view(1, 1, self.block_size, self.block_size)
        # Registered as a buffer so it follows .to()/.cuda(); non-persistent so that
        # state_dict keys are unchanged.
        self.register_buffer('mask', causal, persistent=False)
        # Tolerate configs that do not define `dropout`.
        dropout = getattr(config, 'dropout', 0.0)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.masked = masked

    def forward(self, x, pos_encoding):
        """Apply self-attention over the sequence axis.

        Args:
            x: input features, shape (B, T, C).
            pos_encoding: positional term added to both queries and keys;
                must broadcast against (B, T, C).

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape #batch_size, sequence length, n_embd
        head_size = C//self.n_heads
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q + pos_encoding
        k = k + pos_encoding

        q = q.view(B, T, self.n_heads, head_size).transpose(1,2) # (B, n_heads, T, h_size)
        k = k.view(B, T, self.n_heads, head_size).transpose(1,2)
        v = v.view(B, T, self.n_heads, head_size).transpose(1,2)

        att = q@k.transpose(-2,-1)*(1/math.sqrt(k.size(-1))) # (B, n_heads, T, T)
        if self.masked:
            if T <= self.block_size:
                causal = self.mask[:, :, :T, :T]
            else:
                # Sequence longer than the precomputed mask: build one of the right size.
                causal = torch.tril(torch.ones((T, T), device=att.device)).view(1, 1, T, T)
            att = att.masked_fill(causal == 0, -float('inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att@v # (B, n_heads, T, h_size)
        y = y.transpose(1,2).contiguous().view(B,T,C) # heads merged back: (B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        """
        Args:
            config: object exposing ``bias`` and the embedding width as either
                ``n_embd`` or ``d`` (``n_embd`` wins when both exist);
                ``dropout`` is optional and defaults to 0.0.
        """
        super().__init__()
        # The attention layers in this file read the width as `config.d`; accept
        # either spelling so both styles of config keep working.
        n_embd = config.n_embd if hasattr(config, 'n_embd') else config.d
        self.c_fc    = nn.Linear(n_embd, 4 * n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * n_embd, n_embd, bias=config.bias)
        self.dropout = nn.Dropout(getattr(config, 'dropout', 0.0))

    def forward(self, x):
        """Map (..., n_embd) -> (..., n_embd) through the 4x-wide hidden layer."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

#query-based splitting module
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then an MLP.

    Each sub-layer is followed by a LayerNorm; the cross-attention and MLP
    outputs are added to their inputs before normalisation.
    """

    def __init__(self, config):
        """
        Args:
            config: forwarded to SelfAttention, CrossAttention and MLP; must
                expose the embedding width as ``d`` (``n_embd`` is used for
                the LayerNorms when present).
        """
        super().__init__()
        self.d = config.d
        # The attention layers read the width as `config.d`; accept either spelling.
        n_embd = config.n_embd if hasattr(config, 'n_embd') else config.d
        self.self_attention = SelfAttention(config, masked=True)
        self.ln_1 = nn.LayerNorm(n_embd)
        self.cross_attention = CrossAttention(config)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(config)
        self.ln_3 = nn.LayerNorm(n_embd)

    def forward(self, prev_output, pos_encoding, encoder_output):
        """Run one decoder layer.

        Args:
            prev_output: features from the previous layer (or the initial queries).
            pos_encoding: positional term passed to both attention layers.
            encoder_output: memory attended to by the cross-attention layer.

        Returns:
            Tensor with the same shape as ``prev_output``.
        """
        # NOTE(review): unlike the two sub-layers below, self-attention has no residual
        # connection here — confirm whether that is intentional.
        x = self.self_attention(prev_output, pos_encoding)
        x = self.ln_1(x)
        # Out-of-place adds: `x` is saved by the sub-layer's Linear for the backward
        # pass, so an in-place `+=` would invalidate it and make autograd raise.
        x = x + self.cross_attention(x, pos_encoding, encoder_output)
        x = self.ln_2(x)
        x = x + self.mlp(x)
        x = self.ln_3(x)
        return x
# class Decoder(nn.Module): #transformer decoder
#     def __init__(self, config):
#         super().__init__()
#         self.d = config.d
#     def forward(self, x):






class TRUST(nn.Module):
    """Table-structure model: visual backbone, row/column query decoders, vertex head.

    Row and column queries are decoded against the finest pyramid level; the
    resulting features feed per-row and per-column linear heads. A second pair
    of decoder stacks lets row features attend to column features (and vice
    versa) before they are combined into an (N x M) grid for the vertex head.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``d``, ``N``, ``M``, ``num_decoders`` and
                ``bias``, plus whatever ResNet_with_FPN and DecoderBlock read.
        """
        # Must run before any sub-module is assigned to `self`.
        super().__init__()
        self.d = config.d
        self.N = config.N
        self.M = config.M
        self.num_decoders = config.num_decoders
        self.bias = config.bias
        self.visual_features = ResNet_with_FPN(config)
        self.row_embedding = nn.Embedding(self.N, self.d)
        self.col_embedding = nn.Embedding(self.M, self.d)
        self.row_decoder = nn.ModuleList([DecoderBlock(config) for _ in range(self.num_decoders)])
        self.col_decoder = nn.ModuleList([DecoderBlock(config) for _ in range(self.num_decoders)])
        self.row_separator_decoder = nn.ModuleList([DecoderBlock(config) for _ in range(self.num_decoders)])
        self.col_separator_decoder = nn.ModuleList([DecoderBlock(config) for _ in range(self.num_decoders)])
        self.row_fc = nn.Linear(self.d, 3, bias=self.bias)
        self.col_fc = nn.Linear(self.d, 3, bias=self.bias)
        self.vertex_fc = nn.Linear(self.d, 4, bias=self.bias)
        # Loss criteria for the (not yet written) training branch; built once here
        # instead of on every forward call. They hold no parameters.
        # NOTE(review): BCELoss expects probabilities while the heads emit raw linear
        # outputs — confirm whether a sigmoid or BCEWithLogitsLoss is intended.
        self.row_cls_criterion = nn.BCELoss()
        self.col_cls_criterion = nn.BCELoss()
        self.link_cls_criterion = nn.BCELoss()
        self.angle_cls_criterion = nn.CrossEntropyLoss()
        self.start_point_criterion = nn.SmoothL1Loss(beta=1)

    def forward(self, x, target=None):
        """Predict separator and vertex outputs for an image batch.

        Args:
            x: image batch; its first dimension is the batch size B.
            target: training targets. Loss computation is not implemented yet.

        Returns:
            Tuple ``(row_separators, col_separators, merge_features)`` with
            shapes (B, N, 3), (B, M, 3) and (B, N, M, 4).

        Raises:
            NotImplementedError: if ``target`` is given.
        """
        B = x.shape[0]
        pyramid = self.visual_features(x)
        # Finest pyramid level, flattened to a sequence: (B, d, H, W) -> (B, H*W, d).
        P2 = pyramid[0].flatten(start_dim=2).transpose(1,2)

        # Build the index tensors on the embedding's device so the model also works
        # when it lives somewhere other than the default device.
        index_device = self.row_embedding.weight.device
        row_queries = self.row_embedding(torch.arange(self.N, device=index_device).repeat(B,1))
        col_queries = self.col_embedding(torch.arange(self.M, device=index_device).repeat(B,1))

        row_features = row_queries
        for block in self.row_decoder:
            row_features = block(row_features, row_queries, P2) #B, N, d
        col_features = col_queries
        for block in self.col_decoder:
            col_features = block(col_features, col_queries, P2) #B, M, d

        row_separators = self.row_fc(row_features) #B, N, 3
        col_separators = self.col_fc(col_features) #B, M, 3

        # The enhancement decoders and vertex_fc operate on d-wide tensors, so they are
        # fed the d-wide decoder features rather than the 3-wide head outputs.
        enhanced_row_features = row_features
        for block in self.row_separator_decoder:
            enhanced_row_features = block(enhanced_row_features, row_queries, col_features)
        enhanced_col_features = col_features
        for block in self.col_separator_decoder:
            enhanced_col_features = block(enhanced_col_features, col_queries, row_features)

        # Broadcast (B, N, 1, d) + (B, 1, M, d) -> one d-wide feature per grid vertex.
        vertex_features = enhanced_row_features.unsqueeze(2) + enhanced_col_features.unsqueeze(1)
        merge_features = self.vertex_fc(vertex_features) #B, N, M, 4

        if target is not None:
            # TODO: online hard example mining and the loss terms built from the
            # criteria created in __init__.
            raise NotImplementedError("TRUST.forward: loss computation for `target` is not implemented yet")

        return row_separators, col_separators, merge_features
            
            

        
    
        

        
        



        

        