In [1]:
import os
import glob
import numpy as np
import pandas as pd
import random
import math
import gc
import cv2
from tqdm import tqdm
import time
from functools import lru_cache
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm
from timm.scheduler import CosineLRScheduler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef
from sklearn.model_selection import train_test_split
import time
import io
import os
import sys
import time
import json
import math
import glob
import datetime
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict, deque

import torch
import torchvision
import torch.distributed as dist
from torch import nn, einsum
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torch.utils.checkpoint as checkpoint

from typing import Iterable, Optional
from timm.models import create_model
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.data import Mixup,create_transform
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils import accuracy, ModelEma,NativeScaler, get_state_dict, ModelEma

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from functools import partial
from einops import rearrange

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [None]:
class PatchEmbed(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))

In [None]:
def similiarity(x1,x2):
	p12 = (x1*x2).sum(-1)
	p1 = torch.sqrt((x1*x1).sum(-1))
	p2 = torch.sqrt((x2*x2).sum(-1))
	s = p12/(p1*p2+1e-6)
	return s


def criterion_global_consistency(g_m, p_m, g_a, p_a):
	loss =  -0.5*(similiarity(g_m, p_m)+similiarity(g_a, p_a))
	loss = loss.mean()
	return loss


In [20]:
model = timm.create_model ('efficientnetv2_m',pretrained=False, drop_rate = 0.2, drop_path_rate = 0.1)

In [21]:
for name , model1 in model.named_children():
    if name == "blocks":
        for name1 , model2 in model1.named_children():
            print(name1)

0
1
2
3
4
5
6


In [5]:
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [3]:
class Backbone(nn.Module):
    def __init__(self, ):
        super(Backbone, self).__init__()
        self.register_buffer('mean', torch.FloatTensor([0.5, 0.5, 0.5]).reshape(1,3,1,1))
        self.register_buffer('std',  torch.FloatTensor([0.5, 0.5, 0.5]).reshape(1,3,1,1))
        self.encoder = timm.create_model ('efficientnetv2_m',pretrained=False, drop_rate = 0.2, drop_path_rate = 0.1)

    def forward(self, x):
        features=[]
        def hook(module, input, output):
            features.append(input)
            return None
        x = (x - self.mean) / self.std
        self.encoder.blocks[4].register_forward_hook(hook)
        self.encoder.blocks[5].register_forward_hook(hook)
        self.encoder.blocks[6].register_forward_hook(hook)
        x = self.encoder.forward_features(x)
        return x,features

In [7]:
class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """
    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [8]:
class GlobalConsistency(nn.Module):
    def __init__(self,dim):
        super(GlobalConsistency, self).__init__()
        self.pj = nn.Linear(dim,128)
        self.project = nn.Linear(dim,dim) #<todo> try mlp?

    def forward(self, u_m, u_a):
        B, C, H, W = u_m.shape

        g_m = F.adaptive_max_pool2d(u_m,1)
        g_a = F.adaptive_max_pool2d(u_a,1)
        g_m = torch.flatten(g_m, 1)
        g_a = torch.flatten(g_a, 1)
        pj_m = self.pj(g_m)
        pj_a = self.pj(g_a)
        p_a = self.project(g_m)
        p_m = self.project(g_a)

        return g_m, p_m, g_a, p_a ,pj_m,pj_a

In [11]:
class CrossAttention(nn.Module):
	def __init__(self, dim, num_head=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
		super(CrossAttention, self).__init__()

		assert dim % num_head == 0, 'dim should be divisible by num_heads'
		self.num_head = num_head
		head_dim = dim // num_head
		self.scale = head_dim ** (-0.5)

		self.qkv_a = nn.Linear(dim, dim * 3, bias=qkv_bias)
		self.qkv_m = nn.Linear(dim, dim * 3, bias=qkv_bias)
		self.proj_a = nn.Linear(dim, dim)
		self.proj_m = nn.Linear(dim, dim)

		self.attn_drop = nn.Dropout(attn_drop)
		self.proj_drop = nn.Dropout(proj_drop)

	def forward(self, u_m, u_a):
		B,L,dim = u_m.shape

		qkv_m = self.qkv_m(u_m)
		qkv_m = qkv_m.reshape(B, L, 3, self.num_head, dim // self.num_head).permute(2, 0, 3, 1, 4)
		q_m, k_m, v_m = qkv_m.unbind(0)

		qkv_a = self.qkv_m(u_a)
		qkv_a = qkv_a.reshape(B, L, 3, self.num_head, dim // self.num_head).permute(2, 0, 3, 1, 4)
		q_a, k_a, v_a = qkv_a.unbind(0)

		attn_m = (q_m @ k_a.transpose(-2, -1)) * self.scale
		attn_m = attn_m.softmax(dim=-1)
		attn_m = self.attn_drop(attn_m)

		attn_a = (q_a @ k_m.transpose(-2, -1)) * self.scale
		attn_a = attn_a.softmax(dim=-1)
		attn_a = self.attn_drop(attn_a)


		x_m = (attn_m @ v_m).transpose(1, 2).reshape(B, L, dim)
		x_m = self.proj_m(x_m)
		x_m = self.proj_drop(x_m)

		x_a = (attn_a @ v_a).transpose(1, 2).reshape(B, L, dim)
		x_a = self.proj_a(x_a)
		x_a = self.proj_drop(x_a)

		return  x_m, x_a

In [12]:
class LocalCoccurrence(nn.Module):
	def __init__(self,dim, num_head=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
		super(LocalCoccurrence, self).__init__()

		self.norm1 = nn.LayerNorm(dim)
		self.attn  = CrossAttention(dim, num_head, qkv_bias, attn_drop, proj_drop)

	def forward(self, u_m, u_a):
		B,C,H,W = u_m.shape
		L = H*W
		dim = C

		u_m = u_m.reshape(B,dim,L).permute(0,2,1)
		u_a = u_a.reshape(B,dim,L).permute(0,2,1)

		x_m = self.norm1(u_m)
		x_a = self.norm1(u_a)
		x_m, x_a = self.attn(x_m, x_a)

		gap_m = x_m.mean(1)
		gap_a = x_a.mean(1)
		return gap_m, gap_a

In [13]:
class Net(nn.Module):
    def load_pretrain(self, ):
        return

    def __init__(self,):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']

        self.backbone_m = Backbone()
        self.backbone_a = Backbone()
        dim = 1792

        self.lc  = LocalCoccurrence(dim)
        self.gl  = GlobalConsistency(dim)
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim*3),
            nn.Linear(dim*3, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )#<todo> mlp needs to be deep if backbone is strong?
        self.cancer = nn.Linear(dim,1)

    def forward(self, batch):
        batch_size,C,H,W = batch.shape
        x = batch.reshape(-1, C, H, W)
        
        x_m =torch.tensor( np.array_split(batch,2,axis=3)[0])

        print(x_m.shape)
        x_a = torch.tensor( np.array_split(batch,2,axis=3)[1])
        x_m=x_m.to(torch.float32)
        x_a=x_a.to(torch.float32)
        u_m = self.backbone_m(x_m)
        u_a = self.backbone_a(x_a)
        # u_m=transforms.ToTensor()(u_m)
        # u_a=transforms.ToTensor()(u_a)
        
        gap_m, gap_a = self.lc(u_m, u_a)

        g_m, p_m, g_a, p_a = self.gl(u_m, u_a)
        gp_m = g_m + p_m

        last = torch.cat([gp_m, gap_m, gap_a ],-1)
        last = self.mlp(last)
        cancer = self.cancer(last).reshape(-1)


#         output = {}
#         if  'loss' in self.output_type:
#             output['cancer_loss'] = F.binary_cross_entropy_with_logits(cancer, batch['cancer'])
#             output['global_loss'] = criterion_global_consistency(g_m, p_m, g_a, p_a)


#         if 'inference' in self.output_type:
#             output['cancer'] = torch.sigmoid(cancer)

        return torch.sigmoid(cancer)

In [9]:
m = Backbone()

In [16]:
i = np.random.random(size=(1,3,1024,1024))
i.shape

(1, 3, 1024, 512)

In [None]:
# np.concatenate((np.array_split(i,2,axis=3)[0],np.array_split(i,2,axis=3)[1]),axis=0)

In [None]:
# np.array_split(i,2,axis=3)[0]

In [4]:
net=Backbone()

In [17]:
net.to(device)
# i.to(device)
i = torch.tensor( i).to(device).to(torch.float32)

In [18]:
x,x1 = net(i)

In [19]:
for i in x1:
    print(i[0].shape)

torch.Size([1, 160, 64, 32])
torch.Size([1, 176, 64, 32])
torch.Size([1, 304, 32, 16])


In [30]:

class LayerActivations:
    features = None
 
    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
 
    def hook_fn(self, module, input, output):
        self.features = output.cpu()
 
    def remove(self):
        self.hook.remove()


AttributeError: 'Backbone' object has no attribute 'Backbone'

In [5]:

from torchsummary import summary

In [8]:

# model=ESPNet_Encoder()
summary(net, (3, 1024, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        Conv2dSame-1         [-1, 24, 512, 256]             648
          Identity-2         [-1, 24, 512, 256]               0
              SiLU-3         [-1, 24, 512, 256]               0
    BatchNormAct2d-4         [-1, 24, 512, 256]              48
            Conv2d-5         [-1, 24, 512, 256]           5,184
          Identity-6         [-1, 24, 512, 256]               0
              SiLU-7         [-1, 24, 512, 256]               0
    BatchNormAct2d-8         [-1, 24, 512, 256]              48
          Identity-9         [-1, 24, 512, 256]               0
        ConvBnAct-10         [-1, 24, 512, 256]               0
           Conv2d-11         [-1, 24, 512, 256]           5,184
         Identity-12         [-1, 24, 512, 256]               0
             SiLU-13         [-1, 24, 512, 256]               0
   BatchNormAct2d-14         [-1, 24, 5

In [96]:
import netron
import torch.onnx
from torch.autograd import Variable

In [97]:
torch.save(net.state_dict(), './save.pt')

In [98]:
modelData = "./save.pt"  # 定义模型数据保存的路径

In [99]:
torch.onnx.export(net, x, modelData)

RuntimeError: The size of tensor a (1792) must match the size of tensor b (3) at non-singleton dimension 1

In [None]:
netron.start(modelData)

In [None]:
def build_dataset(is_train, args):
    transform = build_transform(is_train, args)
    if not args.use_mcloader:
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        print(f"/n-----------------/n{root}")
        dataset = datasets.ImageFolder(root, transform=transform)
        print(dataset)
    else:
        from mcloader import ClassificationDataset
        dataset = ClassificationDataset(
            'train' if is_train else 'val',
            pipeline=transform
        )
    nb_classes = 2

    return dataset, nb_classes

In [None]:
def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        train_transform = transforms.Compose([
                transforms.Resize((args.input_size, args.input_size//1.2)),  # 缩放
                # transforms.RandomCrop(32, padding=4),  # 随机裁剪
                transforms.ToTensor(),  # 图片转张量，同时归一化0-255 ---》 0-1
                transforms.Normalize(norm_mean, norm_std),  # 标准化均值为0标准差为1
            ])
        return train_transform

    t = []
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

In [None]:
def main(args):
    init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)
    args.distributed = False
    # print(len(dataset_train))
    tmp = build_weight(args.data_path)
    sampler_train = torch.utils.data.WeightedRandomSampler(tmp,args.nb_classes*len(dataset_train))
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    
    # for i in sampler_train:
    #     print(i)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=250,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )
    print(f"Creating model: {args.model}")
    model = Model()

    if not args.distributed or args.rank == 0:
        print(model)
        input_tensor = torch.zeros((1, 3, 1280, 640), dtype=torch.float32)
        model.eval()
        cal_flops_params_with_fvcore(model, input_tensor)

    model.to(device)
    model_ema = None


    model_without_ddp = model

    linear_scaled_lr = args.lr * args.batch_size * get_world_size() / 512.0

    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)

    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    # if args.mixup > 0.:
    #     # smoothing is handled with mixup label transform
    #     criterion = SoftTargetCrossEntropy()
    # elif args.smoothing:
    #     criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    # else:
    #     criterion = torch.nn.CrossEntropyLoss()

    # criterion = DistillationLoss(
    #     criterion, None, 'none', 0, 0
    # )
    criterion = mixloss()
    
    if not args.output_dir:
        args.output_dir = args.model
        if is_main_process():
            import os
            if not os.path.exists(args.model):
                os.mkdir(args.model)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        if 'model' in checkpoint:
            model_without_ddp.load_state_dict(checkpoint['model'])
        else:
            model_without_ddp.load_state_dict(checkpoint)
        if not args.finetune:
            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                if not args.finetune:
                    args.start_epoch = checkpoint['epoch'] + 1
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        if hasattr(model.module, "merge_bn"):
            print("Merge pre bn to speedup inference.")
            model.module.merge_bn()
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    if args.throughout:
        from logger import create_logger
        logger = create_logger(output_dir=output_dir, dist_rank=get_rank(), name=args.model)
        throughput(data_loader_val, model, logger)
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=True,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        if test_stats["acc1"] > max_accuracy:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint_best.pth']
                for checkpoint_path in checkpoint_paths:
                    save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch}

        if args.output_dir and is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
