In [1]:
import os
import glob
import numpy as np
import pandas as pd
import random
import math
import gc
import cv2
from tqdm import tqdm
import time
from functools import lru_cache
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm
from timm.scheduler import CosineLRScheduler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef
from sklearn.model_selection import train_test_split
import time
import io
import os
import sys
import time
import json
import math
import glob
import datetime
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict, deque

import torch
import torchvision
import torch.distributed as dist
from torch import nn, einsum
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torch.utils.checkpoint as checkpoint

from typing import Iterable, Optional
from timm.models import create_model
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.data import Mixup,create_transform
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils import accuracy, ModelEma,NativeScaler, get_state_dict, ModelEma

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from functools import partial
from einops import rearrange

In [2]:
# Run on the first CUDA device when one is visible to PyTorch, else on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

# utils

In [3]:
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every rank and join the copies along dim 0.

    One placeholder shaped like `tensor` is allocated per rank
    (`torch.distributed.get_world_size()` of them), the placeholders are
    filled by a blocking `torch.distributed.all_gather`, and the filled
    placeholders are concatenated in list order.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()

    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensor))

    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

# build

In [4]:
class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention

    Queries are computed from every input token. When ``sr_ratio > 1`` the
    keys and values are computed from a token sequence shortened by 1-D
    average pooling (window and stride ``sr_ratio ** 2``) followed by a
    BatchNorm1d; otherwise they are computed from the input unchanged.

    Args:
        dim: channel width of the input tokens and of q/k/v.
        out_dim: width produced by the output projection; defaults to ``dim``.
        head_dim: channels per head; ``num_heads = dim // head_dim``.
        qkv_bias: whether the q/k/v linear layers carry a bias.
        qk_scale: multiplier for the attention logits; a falsy value selects
            ``head_dim ** -0.5``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        sr_ratio: reduction ratio for the key/value sequence (1 disables it).
    """
    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = dim if out_dim is None else out_dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5

        # Sub-module names and creation order are kept so state_dict keys and
        # seeded initialisation are unchanged.
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            # Only built when reduction is active; NORM_EPS is a module-level constant.
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        """
        Pass the q/k/v layers and `pre_bn` to `merge_pre_bn`, then mark as merged.

        With ``sr_ratio > 1`` the key/value calls also receive ``self.norm``.
        `merge_pre_bn` is not defined in this block.
        NOTE(review): presumably it folds the BatchNorm(s) into the linear
        weights - confirm against its definition.
        """
        merge_pre_bn(self.q, pre_bn)
        extra_norm = (self.norm,) if self.sr_ratio > 1 else ()
        merge_pre_bn(self.k, pre_bn, *extra_norm)
        merge_pre_bn(self.v, pre_bn, *extra_norm)
        self.is_bn_merged = True

    def forward(self, x):
        """
        Apply self-attention to ``x`` of shape (B, N, C).

        Returns a tensor of shape (B, N, out_dim).
        """
        batch, tokens, channels = x.shape
        head_width = int(channels // self.num_heads)

        def split_heads(t, length):
            # (B, length, C) -> (B, length, heads, head_width)
            return t.reshape(batch, length, self.num_heads, head_width)

        queries = split_heads(self.q(x), tokens).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            # Pool along the token axis, which AvgPool1d expects last.
            kv_source = self.sr(x.transpose(1, 2))
            # The norm is skipped during ONNX export and once it has been merged.
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
                kv_source = self.norm(kv_source)
            kv_source = kv_source.transpose(1, 2)

        # Keys are laid out as (B, heads, head_width, M) so `queries @ keys_t` works directly.
        keys_t = split_heads(self.k(kv_source), -1).permute(0, 2, 3, 1)
        values = split_heads(self.v(kv_source), -1).permute(0, 2, 1, 3)

        weights = ((queries @ keys_t) * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

In [44]:
NORM_EPS = 1e-5
class PatchEmbed(nn.Module):
    """
    Optional average pooling, then a bias-free 1x1 convolution and BatchNorm2d.

    The pooling window is chosen from the (stride, mode) pair:

        stride=4, mode="V" -> window (4, 32)
        stride=2, mode="V" -> window (2, 16)
        stride=2, mode="H" -> window (64, 2)
        stride=1, mode="H" -> window (32, 1)

    Any other pair skips pooling. In that case the convolution and norm are
    still built when ``in_channels != out_channels``; when the channel counts
    match, the whole module is an identity.
    """
    # (stride, mode, pooling window) triples, checked in order; first match wins.
    _POOLED_LAYOUTS = (
        (4, "V", (4, 32)),
        (2, "V", (2, 16)),
        (2, "H", (64, 2)),
        (1, "H", (32, 1)),
    )

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 mode = "V"
                ):
        super(PatchEmbed, self).__init__()

        pool = None
        for known_stride, known_mode, window in self._POOLED_LAYOUTS:
            if stride == known_stride and mode == known_mode:
                pool = nn.AvgPool2d(window, stride=known_stride,
                                    ceil_mode=True, count_include_pad=False)
                break

        if pool is None and in_channels == out_channels:
            # Nothing to pool and nothing to re-project.
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()
            return

        self.avgpool = nn.Identity() if pool is None else pool
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)

    def forward(self, x):
        """Apply pooling, convolution and normalisation, in that order."""
        pooled = self.avgpool(x)
        projected = self.conv(pooled)
        return self.norm(projected)

In [204]:
class E_PQKV(nn.Module):
    """
    product QKV

    Projects a token sequence with three linear layers and returns the
    results split into heads. ``out_dim``, ``qk_scale``, ``attn_drop``,
    ``proj_drop`` and ``sr_ratio`` are accepted for signature parity with the
    attention classes; only ``out_dim`` and ``qk_scale`` are stored, and none
    of them affect ``forward``.
    """
    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = dim if out_dim is None else out_dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        # Creation order q -> k -> v is kept so seeded initialisation is unchanged.
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)

    def forward(self, x):
        """
        Project ``x`` of shape (B, N, C) and split each projection into heads.

        Returns:
            q: (B, heads, N, C // heads)
            k: (B, heads, C // heads, N), already transposed for ``q @ k``
            v: (B, heads, N, C // heads)
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads
        width = int(channels // heads)

        queries = self.q(x).reshape(batch, tokens, heads, width).permute(0, 2, 1, 3)
        keys_t = self.k(x).reshape(batch, -1, heads, width).permute(0, 2, 3, 1)
        values = self.v(x).reshape(batch, -1, heads, width).permute(0, 2, 1, 3)
        return queries, keys_t, values
    

In [211]:
class E_CA(nn.Module):
    """
    product cross attention

    Scaled dot-product attention over pre-split heads, followed by a linear
    projection and dropout.

    Args:
        dim: channel width after the heads are merged (input of the projection).
        out_dim: channel width produced by the projection.
        qk_scale: multiplier for the attention logits. ``None`` (the default)
            selects ``head_width ** -0.5`` at call time, where ``head_width``
            is the last dimension of ``q``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the projection.
    """
    def __init__(self, dim,out_dim,qk_scale=None, attn_drop=0, proj_drop=0.):
        super().__init__()
        self.scale = qk_scale
        self.proj = nn.Linear(dim,out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q,k,v,b,n,c):
        """
        Attend with ``q`` over ``k`` and mix ``v``.

        ``k`` must already be transposed so that ``q @ k`` is valid, and
        ``v`` must have as many rows as ``q @ k`` has columns. ``b``, ``n``
        and ``c`` give the (batch, tokens, channels) shape that the merged
        heads are reshaped to before the projection.

        Returns a tensor of shape (b, n, out_dim).
        """
        # Multiplying by None raised a TypeError, so the default qk_scale is
        # replaced by the usual 1/sqrt(head width).
        scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5
        attn = (q @ k) * scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        attn = (attn @ v).transpose(1, 2).reshape(b,n,c)
        attn = self.proj(attn)
        attn = self.proj_drop(attn)
        return attn

In [252]:
class E_MTDA(nn.Module):
    """
    Efficient Multi-Head Three-Dimensional Attention
    x_mh,x_ah,x_mv,x_av

    One shared ``E_PQKV`` projects each of the four token streams, and one
    shared ``E_CA`` pairs them up: the ``mh`` stream attends against the keys
    of ``av`` (and vice versa), and ``ah`` against the keys of ``mv`` (and
    vice versa). In every pairing the queries and the values come from the
    same stream. Each attended stream gets a residual connection to its own
    input, is averaged over the token axis, and the four averages are
    concatenated.

    Args:
        dim: channel width of every input stream.
        head_dim: channels per head; ``num_heads = dim // head_dim``.
        out_dim: width produced by the attention projection; defaults to
            ``dim``. The residual add in ``forward`` needs it to be
            compatible with ``dim``.
        qkv_bias, qk_scale, attn_drop, proj_drop, sr_ratio: forwarded to
            ``E_PQKV``; ``attn_drop`` and ``proj_drop`` also go to ``E_CA``.

    ``self.proj``, ``self.attn_drop``, ``self.proj_drop``, ``self.sr`` and
    ``self.norm`` are created but not used by ``forward``; they are kept so
    existing state_dict keys stay valid.
    """
    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = E_PQKV( dim, head_dim, out_dim, qkv_bias, qk_scale,attn_drop, proj_drop, sr_ratio)
        self.ca = E_CA(self.dim,self.out_dim,self.scale,attn_drop,proj_drop)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        """
        Pass the shared q/k/v layers and `pre_bn` to `merge_pre_bn`.

        `merge_pre_bn` is not defined in this block.
        NOTE(review): presumably it folds BatchNorm(s) into the linear
        weights - confirm against its definition.
        """
        # The linear layers live on the shared E_PQKV sub-module; this class
        # has no `self.q` / `self.k` / `self.v` attributes of its own.
        merge_pre_bn(self.qkv.q, pre_bn)
        if self.sr_ratio > 1:
            # NOTE(review): forward() never applies self.sr / self.norm, so
            # confirm that handing self.norm to the merge is intended here.
            merge_pre_bn(self.qkv.k, pre_bn, self.norm)
            merge_pre_bn(self.qkv.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.qkv.k, pre_bn)
            merge_pre_bn(self.qkv.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x_mh,x_ah,x_mv,x_av):
        """
        Cross-attend the four streams and pool each over its tokens.

        The (B, N, C) shape is read from ``x_mh`` and reused for all four
        streams, so they are expected to share it.

        Returns a tensor of shape (B, 4 * out_dim), ordered as
        [mh|av, ah|mv, mv|ah, av|mh] (query stream | key stream).
        """
        B, N, C = x_mh.shape
        q_mh, k_mh, v_mh = self.qkv(x_mh)
        q_mv, k_mv, v_mv = self.qkv(x_mv)
        q_ah, k_ah, v_ah = self.qkv(x_ah)
        q_av, k_av, v_av = self.qkv(x_av)

        # NOTE(review): the values are taken from the *query* stream rather
        # than from the key stream - confirm this pairing is intended.
        attn_mhav = self.ca(q_mh, k_av, v_mh, B, N, C)
        attn_ahmv = self.ca(q_ah, k_mv, v_ah, B, N, C)
        attn_mvah = self.ca(q_mv, k_ah, v_mv, B, N, C)
        attn_avmh = self.ca(q_av, k_mh, v_av, B, N, C)

        # Out-of-place residual adds: same values as `+=`, without mutating
        # a tensor that autograd may still reference.
        attn_mhav = attn_mhav + x_mh
        attn_ahmv = attn_ahmv + x_ah
        attn_mvah = attn_mvah + x_mv
        attn_avmh = attn_avmh + x_av

        pooled = [t.mean(1) for t in (attn_mhav, attn_ahmv, attn_mvah, attn_avmh)]
        return torch.cat(pooled, dim=1)

In [257]:
class ShardExamine(nn.Module):
    """
    Embed two feature maps along two pooling layouts and cross-attend them.

    Each input map is embedded twice: by ``PEV`` (a ``PatchEmbed`` in mode
    "V" with stride ``stride * 2``) and by ``PEH`` (mode "H" with stride
    ``stride``). The four embeddings are flattened to token sequences,
    layer-normalised, and passed to ``E_MTDA``.

    NOTE(review): ``qkv_bias``, ``attn_drop`` and ``proj_drop`` are accepted
    but never forwarded, and ``num_head`` is passed to ``E_MTDA`` in the
    position of its ``head_dim`` parameter - confirm both are intended.
    """

    def __init__(self,dim,
                 in_channels,
                 out_channels,
                 num_head=64,
                 qkv_bias=False,
                 attn_drop=0.1,
                 proj_drop=0.1,
                 stride=1
                ):
        super().__init__()
        self.PEV = PatchEmbed(in_channels, out_channels, stride * 2, mode="V")
        self.PEH = PatchEmbed(in_channels, out_channels, stride, mode="H")
        self.norm1 = nn.LayerNorm(dim)
        self.attn = E_MTDA(dim, num_head)

    def forward(self, u_m, u_a):
        """
        Return the ``E_MTDA`` output for the two input feature maps.

        The token count and channel width are read from the "V" embedding of
        ``u_m`` and reused to flatten all four embeddings.
        """
        # Embedding order is kept (V before H, u_m before u_a): BatchNorm
        # running statistics depend on it.
        vert_m = self.PEV(u_m)
        vert_a = self.PEV(u_a)
        horiz_m = self.PEH(u_m)
        horiz_a = self.PEH(u_a)

        batch, channels, height, width = vert_m.shape
        tokens = height * width

        def as_tokens(fmap):
            # (B, C, H, W) -> (B, L, C)
            return fmap.reshape(batch, channels, tokens).permute(0, 2, 1)

        flat_mv, flat_av, flat_mh, flat_ah = [
            as_tokens(fmap) for fmap in (vert_m, vert_a, horiz_m, horiz_a)
        ]
        tok_mv, tok_av, tok_mh, tok_ah = [
            self.norm1(seq) for seq in (flat_mv, flat_av, flat_mh, flat_ah)
        ]

        # E_MTDA argument order: x_mh, x_ah, x_mv, x_av
        return self.attn(tok_mh, tok_ah, tok_mv, tok_av)

In [170]:
class basNet(nn.Module):
    """
    Two-view network (only two views).

    The input image is split in half along its width. Both halves go through
    a query encoder and a momentum-updated key encoder, in the style of a
    MoCo queue. The module returns a per-sample sigmoid score (``cancer``),
    the contrastive logits, a label vector, and the inputs captured at three
    encoder stages.

    Args:
        K: number of keys held in the queue.
        m: momentum of the key-encoder update.
        T: softmax temperature applied to the logits.
        dim: width of the projected query/key embeddings (queue rows).
        nc: width of the encoder output (``num_classes`` of the backbone).
    """
    def load_pretrain(self, ):
        """Placeholder; currently does nothing."""
        return

    def __init__(self,K=8192, m=0.999, T=0.07,dim=128,nc=512):
        super(basNet, self).__init__()

        self.K = K
        self.m = m
        self.T = T
        self.encoder_q = timm.create_model ('efficientnetv2_m',
                                          pretrained=False,
                                          drop_rate = 0.2,
                                          drop_path_rate = 0.1,
                                          num_classes=nc
                                         )
        self.encoder_k = timm.create_model ('efficientnetv2_m',
                                          pretrained=False,
                                          drop_rate = 0.2,
                                          drop_path_rate = 0.1,
                                          num_classes=nc
                                         )

        # Not used by forward(); kept so existing checkpoints still load.
        self.att = ShardExamine(64,304,64)

        # Projection heads, called explicitly in forward() on the encoder
        # output. They are sized from nc/dim so non-default values stay
        # consistent with the encoder output and the queue. The defaults
        # reproduce the former hard-coded 512 -> 128.
        self.encoder_q.fc = nn.Sequential(
            nn.LayerNorm(nc),
            nn.GELU(),
            nn.Linear(nc, dim),
        )
        self.encoder_k.fc = nn.Sequential(
            nn.LayerNorm(nc),
            nn.GELU(),
            nn.Linear(nc, dim),
        )
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # Consumes the two concatenated encoder outputs, hence 2 * nc inputs
        # (1024 with the default nc).
        self.mlp = nn.Sequential(
            nn.LayerNorm(2 * nc),
            nn.Linear(2 * nc, 512),
            nn.GELU(),
            nn.Linear(512, 128),
        )#<todo> mlp needs to be deep if backbone is strong?
        self.cancer = nn.Linear(128,1)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the queue columns starting at the write pointer with `keys`.

        `keys` is (batch, dim); the pointer then advances by `batch`, modulo K.
        """
        # gather keys before updating queue
        # keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index, placed on the same device as the data
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]


    def forward(self, x):
        """
        Input:
            x: a batch of images, (B, C, H, W). It is split along the width
               into two views; when W is odd the first view is one column
               wider, matching numpy.array_split.
        Output:
            cancer:   (B,) sigmoid scores from the fused query features
            logits:   (B, 4 + 2K) positives followed by queue negatives,
                      divided by the temperature
            labels:   long vector, see the note where it is built
            features: inputs captured at encoder_q.blocks[4], [5] and [6],
                      for the first view and then for the second
        """
        features = []

        def hook(module, input, output):
            features.append(input)
            return None

        # Split the incoming batch itself, not a notebook global, and use
        # tensor slicing so device, dtype and autograd are preserved.
        _, _, _, width = x.shape
        half = (width + 1) // 2
        x_m = x[:, :, :, :half]
        x_a = x[:, :, :, half:]

        # Hooks are registered for this call only and removed afterwards.
        # Otherwise every forward() adds three more hooks that keep old
        # activations alive.
        handles = [
            self.encoder_q.blocks[stage].register_forward_hook(hook)
            for stage in (4, 5, 6)
        ]
        try:
            # compute query features
            q_m = self.encoder_q(x_m)  # queries: NxC
            q_a = self.encoder_q(x_a)  # queries: NxC
        finally:
            for handle in handles:
                handle.remove()

        q_m = nn.functional.normalize(q_m, dim=1)
        q_a = nn.functional.normalize(q_a, dim=1)

        last = torch.cat([q_a, q_m ],-1)
        last = self.mlp(last)

        cancer = self.cancer(last).reshape(-1)
        cancer = torch.sigmoid(cancer)
        q_m = self.encoder_q.fc(q_m)
        q_a = self.encoder_q.fc(q_a)
        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            k_m = self.encoder_k(x_m)  # keys: NxC
            k_m = nn.functional.normalize(k_m, dim=1)
            k_m = self.encoder_k.fc(k_m)
            k_a = self.encoder_k(x_a)  # keys: NxC
            k_a = nn.functional.normalize(k_a, dim=1)
            k_a = self.encoder_k.fc(k_a)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1 each
        l_pos1 = torch.einsum("nc,nc->n", [q_m, k_m]).unsqueeze(-1)
        l_pos2 = torch.einsum("nc,nc->n", [q_m, k_a]).unsqueeze(-1)
        l_pos3 = torch.einsum("nc,nc->n", [q_a, k_m]).unsqueeze(-1)
        l_pos4 = torch.einsum("nc,nc->n", [q_a, k_a]).unsqueeze(-1)
        # negative logits: NxK each
        l_neg_m = torch.einsum("nc,ck->nk", [q_m, self.queue.clone().detach()])
        l_neg_a = torch.einsum("nc,ck->nk", [q_a, self.queue.clone().detach()])
        # logits: Nx(4+2K)
        logits = torch.cat([l_pos1,l_pos2,l_pos3,l_pos4, l_neg_m,l_neg_a], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators, created on the logits' device so
        # the model also runs without CUDA.
        # NOTE(review): this vector has one entry per logit *column*
        # (logits.shape[1]), not per sample, and entries 1..3 are set to
        # 1, 2, 3. That does not match a per-sample cross-entropy target.
        # Shape and values are left as they were - confirm the intended
        # target layout with the loss that consumes it.
        labels = torch.zeros(logits.shape[1], dtype=torch.long, device=logits.device)
        labels[1],labels[2],labels[3]=1,2,3

        # dequeue and enqueue
        self._dequeue_and_enqueue(k_m)
        self._dequeue_and_enqueue(k_a)

        return cancer,logits, labels,features

    


### torch.Size([1, 160, 64, 32])
### torch.Size([1, 176, 64, 32])
### torch.Size([1, 304, 32, 16])

In [259]:
# Smoke-test instance: 304-channel feature maps embedded to 64 channels / 64-dim tokens.
net = ShardExamine(dim=64, in_channels=304, out_channels=64).to(device)

In [249]:
# Two random float32 feature maps of shape (2, 304, 32, 16), drawn in the
# order i_m then i_a and moved to `device`.
i_m, i_a = (
    torch.tensor(np.random.random(size=(2, 304, 32, 16))).to(device).to(torch.float32)
    for _ in range(2)
)
i_a.shape

torch.Size([2, 304, 32, 16])

In [260]:
# Forward pass of the smoke-test module on the two random feature maps.
e = net(i_m, i_a)

In [261]:
# Shape of the four token-averaged streams concatenated per sample.
e.shape

torch.Size([2, 256])

In [251]:
# NOTE(review): `a` is not assigned anywhere in the visible notebook; this cell
# relies on leftover kernel state and will fail after Restart & Run All.
a.shape

torch.Size([2, 16, 64])

In [245]:
# NOTE(review): `b` is not assigned anywhere in the visible notebook; this cell
# relies on leftover kernel state and will fail after Restart & Run All.
b.shape


torch.Size([2, 64])

In [124]:
# NOTE(review): `b1` is not assigned anywhere in the visible notebook; this cell
# relies on leftover kernel state and will fail after Restart & Run All.
b1.squeeze(0).shape

torch.Size([1, 16, 64])

In [85]:
# Stand-alone "H"-mode embedding (stride 1) from 160 to 32 channels.
pe = PatchEmbed(in_channels=160, out_channels=32, stride=1, mode="H").to(device)

In [88]:
# NOTE(review): `pe` was built for 160 input channels, while `i_a` as defined
# earlier in this notebook has 304 channels and batch size 2. The recorded
# output shape (1, 32, 1, 16) must come from a different `i_a` left in the
# kernel; re-running top to bottom will not reproduce it.
xx = pe(i_a).to(device)
xx.shape

torch.Size([1, 32, 1, 16])

In [89]:
# Unpack the embedded map's shape; L is the token count, dim the channel width.
B, C, H, W = xx.shape
L, dim = H * W, C

In [90]:
# Flatten (B, C, H, W) to a (B, L, C) token sequence.
# This rebinds `xx`, so re-running the cell alone will not give the same result.
xx = xx.reshape(B, dim, L).transpose(1, 2)
xx.shape

torch.Size([1, 16, 32])

In [None]:
# Builds a LayerNorm over the last `dim` features; the module is not assigned
# to a name or applied to any tensor here.
nn.LayerNorm(dim)