# import

In [2]:
import os
import glob
import numpy as np
import pandas as pd
import random
import math
import gc
import cv2
from tqdm import tqdm
import time
from functools import lru_cache
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm
from timm.scheduler import CosineLRScheduler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef
from sklearn.model_selection import train_test_split
import time
import io
import os
import sys
import time
import json
import math
import glob
import datetime
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict, deque

import torch
import torchvision
import torch.distributed as dist
from torch import nn, einsum
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torch.utils.checkpoint as checkpoint

from typing import Iterable, Optional
from timm.models import create_model
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.data import Mixup,create_transform
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils import accuracy, ModelEma,NativeScaler, get_state_dict, ModelEma

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from functools import partial
from einops import rearrange

# modle build

## device

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## cfg

In [4]:
def get_args_parser():
    parser = argparse.ArgumentParser('Next-ViT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='pvt_small', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-Hsize', default=1280, type=int, help='images input size')
    parser.add_argument('--input-Wsize', default=640, type=int, help='images input size')
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--flops', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=5, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='sched', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-6, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=40, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". " + \
                             "(default: rand-m9-mstd0.5-inc1)'),
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', action='store_true', help='Perform finetune.')

    # Dataset parameters
    parser.add_argument('--data-path', default='../../datasets/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--use-mcloader', action='store_true', default=False, help='Use mcloader')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output-dir', default='../../outputdir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./nextvit_small_in1k6m_384.pth', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # test throught
    parser.add_argument('--throughout', action='store_true', help='Perform throughout only')
    return parser

## Backbone

In [5]:
class Backbone(nn.Module):
    def __init__(self, ):
        super(Backbone, self).__init__()
        self.register_buffer('mean', torch.FloatTensor([0.5, 0.5, 0.5]).reshape(1,3,1,1))
        self.register_buffer('std',  torch.FloatTensor([0.5, 0.5, 0.5]).reshape(1,3,1,1))
        self.encoder = timm.create_model ('efficientnetv2_m',pretrained=False, drop_rate = 0.2, drop_path_rate = 0.1)

    def forward(self, x):
        features=[]
        def hook(module, input, output):
            features.append(input)
            return None
        x = (x - self.mean) / self.std
        self.encoder.blocks[4].register_forward_hook(hook)
        self.encoder.blocks[5].register_forward_hook(hook)
        self.encoder.blocks[6].register_forward_hook(hook)
        x = self.encoder.forward_features(x)
        return x,features

In [7]:
net = Backbone().to(device)

In [9]:
i_m = np.random.random(size=(4, 3, 1024, 512))
# i_a = np.random.random(size=(1, 160, 32, 16))
i_m = torch.tensor(i_m).to(device).to(torch.float32)
# i_a = torch.tensor(i_a).to(device).to(torch.float32)
i_m.shape

torch.Size([4, 3, 1024, 512])

In [11]:
a,b = net(i_m)

In [12]:
a.shape

torch.Size([4, 1280, 32, 16])

In [15]:
b[0][0].shape

torch.Size([4, 160, 64, 32])

In [17]:
b[1][0].shape

torch.Size([4, 176, 64, 32])

In [18]:
b[2][0].shape

torch.Size([4, 304, 32, 16])

In [3]:
class GlobalConsistency(nn.Module):
    def __init__(self,dim):
        super(GlobalConsistency, self).__init__()
        self.pj = nn.Linear(dim,128)
        self.project = nn.Linear(dim,dim) #<todo> try mlp?

    def forward(self, u_m, u_a):
        B, C, H, W = u_m.shape

        g_m = F.adaptive_max_pool2d(u_m,1)
        g_a = F.adaptive_max_pool2d(u_a,1)
        g_m = torch.flatten(g_m, 1)
        g_a = torch.flatten(g_a, 1)
        pj_m = self.pj(g_m)
        pj_a = self.pj(g_a)
        p_a = self.project(g_m)
        p_m = self.project(g_a)

        return g_a, p_a ,pj_m,pj_a

In [4]:
class E_MHCA(nn.Module):
    """
    Efficient Multi-Head Cross Attention
    """
    def __init__(self, dim, out_dim=None, head_dim=16, #随着H*W变化 16/32
                 qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q_m = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k_m = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v_m = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj_m = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.q_a = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k_a = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v_a = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj_a = nn.Linear(self.dim, self.out_dim)
        

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x_mh,x_ah,x_mv,x_av):
        B, N, C = x_m.shape
        # B, N, C = x_a.shape
        q_m = self.q_m(x_m)
        q_m = q_m.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        v_m = self.v_m(x_m)
        v_m = v_m.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        k_m = self.k_m(x_m)
        k_m = k_m.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
        
        q_a = self.q_a(x_a)
        q_a = q_a.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

        k_a = self.k_a(x_a)
        k_a = k_a.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
        v_a = self.v_a(x_a)
        v_a = v_a.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        print(q_m.shape)
        print(k_a.shape)
        attn_m = (q_m @ k_a.transpose(-2, -1)) * self.scale
        attn_m = attn_m.softmax(dim=-1)
        attn_m = self.attn_drop(attn_m)

        attn_a = (q_a @ k_m.transpose(-2, -1)) * self.scale
        attn_a = attn_a.softmax(dim=-1)
        attn_a = self.attn_drop(attn_a)


        x_m = (attn_m @ v_m).transpose(1, 2).reshape(B, N, C)
        x_m = self.proj_m(x_m)
        x_m = self.proj_drop(x_m)

        x_a = (attn_a @ v_a).transpose(1, 2).reshape(B, N, C)
        x_a = self.proj_a(x_a)
        x_a = self.proj_drop(x_a)
        return x_m, x_a

In [5]:
net = E_MHCA(32)

In [6]:
i = np.random.random(size=(1, 16, 32))
i.shape

(1, 16, 32)

In [7]:
net.to(device)
# i.to(device)
i = torch.tensor( i).to(device).to(torch.float32)

In [8]:
x,x2 = net(i,i)

torch.Size([1, 2, 16, 16])
torch.Size([1, 2, 16, 16])


In [90]:
x

tensor([[[ 0.1570, -0.5715,  0.3364, -0.1431,  0.1250,  0.2233, -0.0127,
          -0.1961, -0.0727, -0.2810,  0.1704,  0.2176, -0.2890,  0.0097,
           0.1279,  0.0951,  0.2664,  0.1611,  0.2087,  0.3318, -0.2636,
          -0.5694,  0.0494,  0.2268,  0.2702, -0.3870, -0.3487, -0.1708,
          -0.5044,  0.0967,  0.0758, -0.0313],
         [ 0.1592, -0.5733,  0.3386, -0.1478,  0.1278,  0.2261, -0.0143,
          -0.1968, -0.0744, -0.2779,  0.1693,  0.2218, -0.2902,  0.0081,
           0.1221,  0.0972,  0.2655,  0.1604,  0.2053,  0.3373, -0.2614,
          -0.5762,  0.0426,  0.2278,  0.2748, -0.3941, -0.3523, -0.1711,
          -0.5060,  0.0976,  0.0757, -0.0317],
         [ 0.1563, -0.5721,  0.3366, -0.1450,  0.1259,  0.2239, -0.0133,
          -0.1956, -0.0734, -0.2794,  0.1695,  0.2178, -0.2899,  0.0089,
           0.1261,  0.0960,  0.2657,  0.1607,  0.2074,  0.3338, -0.2628,
          -0.5711,  0.0475,  0.2278,  0.2720, -0.3891, -0.3494, -0.1713,
          -0.5044,  0.0979,  0

### torch.Size([1, 160, 64, 32])
### torch.Size([1, 176, 64, 32])
### torch.Size([1, 304, 32, 16])

In [110]:
NORM_EPS = 1e-5
class PatchEmbed(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 mode = "V"
                ):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 4 and mode == "V":
            self.avgpool = nn.AvgPool2d((4, 32), stride=4, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif stride == 2 and mode == "V":
            self.avgpool = nn.AvgPool2d((2, 16), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif stride == 2 and mode == "H":
            self.avgpool = nn.AvgPool2d((64, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif stride == 1 and mode == "H":
            self.avgpool = nn.AvgPool2d((32, 1), stride=1, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))

In [119]:
pe = PatchEmbed(160,32,1,"H").to(device)

In [120]:
i = np.random.random(size=(1, 160, 32, 16))
i = torch.tensor( i).to(device).to(torch.float32)
i.shape

torch.Size([1, 160, 32, 16])

In [121]:
x = pe(i)

In [122]:
x.shape

torch.Size([1, 32, 1, 16])

In [None]:
class ShardExamine(nn.Module):
    
    def __init__(self,dim,
                 num_head=8, 
                 qkv_bias=False, 
                 attn_drop=0., 
                 proj_drop=0.,
                 in_channels,
                 out_channels,
                 stride=1
                ):
        super(LocalCoccurrence, self).__init__()
        self.PEV = PatchEmbed(in_channels,out_channels,stride*2,mode = "V")
        self.PEH = PatchEmbed(in_channels,out_channels,stride,mode = "H")
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = CrossAttention(dim, num_head, qkv_bias, attn_drop, proj_drop)

    def forward(self, u_m, u_a):
        u_mv = self.PEV(u_m)
        u_av = self.PEV(u_a)
        u_m = self.PEH(u_m)
        u_a = self.PEH(u_a)
        B,C,H,W = u_mv.shape
        L = H*W
        dim = C

        u_mv = u_mv.reshape(B,dim,L).permute(0,2,1)
        u_av = u_av.reshape(B,dim,L).permute(0,2,1)
        u_m = u_m.reshape(B,dim,L).permute(0,2,1)
        u_a = u_a.reshape(B,dim,L).permute(0,2,1)

        u_mv = self.norm1(u_mv)
        u_av = self.norm1(u_av)
        u_m = self.norm1(u_m)
        u_a = self.norm1(u_a)
        x_m, x_a = self.attn(x_m, x_a)
        x_mv, x_av = self.attn(x_mv, x_a)
        gap_m = x_m.mean(1)
        gap_a = x_a.mean(1)
        return gap_m, gap_a

In [None]:
class SCN(nn.Module)
"""
Shard Contrastive Net
"""
    def load_pretrain(self, ):
        return

    def __init__(self,):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']


        self.backbone = Backbone()
        dim = 1280

        self.lc  = LocalCoccurrence(dim)
        self.gl  = GlobalConsistency(dim)
        self.ape = 
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim*3),
            nn.Linear(dim*3, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )#<todo> mlp needs to be deep if backbone is strong?
        self.cancer = nn.Linear(dim,1)

    def forward(self, batch):
        x = batch['image']
        batch_size,num_view,C,H,W = x.shape
        x = x.reshape(-1, C, H, W)

        u = self.backbone(x)
        _,c,h,w = u.shape

        u = u.reshape(batch_size,num_view,c,h,w)
        u_m = u[:,0]
        u_a = u[:,1]
        gap_m, gap_a = self.lc(u_m, u_a)

        g_m, p_m, g_a, p_a = self.gl(u_m, u_a)
        gp_m = g_m + p_m

        last = torch.cat([gp_m, gap_m, gap_a ],-1)
        last = self.mlp(last)
        cancer = self.cancer(last).reshape(-1)


        output = {}
        if  'loss' in self.output_type:
            output['cancer_loss'] = F.binary_cross_entropy_with_logits(cancer, batch['cancer'])
            output['global_loss'] = criterion_global_consistency(g_m, p_m, g_a, p_a)


        if 'inference' in self.output_type:
            output['cancer'] = torch.sigmoid(cancer)

        return output

In [9]:
class basNet(nn.Module):
    """
    only two viwe
    """
    def load_pretrain(self, ):
        return

    def __init__(self,):
        super(basNet, self).__init__()
        # self.output_type = ['inference', 'loss']


        self.backbone =timm.create_model ('efficientnetv2_m',
                                          pretrained=False, 
                                          drop_rate = 0.2, 
                                          drop_path_rate = 0.1,
                                          num_classes=512
                                         )

        # dim = 1280

        # self.lc  = LocalCoccurrence(dim)
        # self.gl  = GlobalConsistency(dim)
        self.mlp = nn.Sequential(
            nn.LayerNorm(1024),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Linear(512, 128),
        )#<todo> mlp needs to be deep if backbone is strong?
        self.cancer = nn.Linear(128,1)

    def forward(self, x):
        # x = batch['image']
        batch_size,C,H,W = x.shape
        x = x.reshape(-1, C, H, W)
        x_m =torch.tensor( np.array_split(batch,2,axis=3)[0])

        # print(x_m.shape)
        x_a = torch.tensor( np.array_split(batch,2,axis=3)[1])
        x_m = self.backbone(x_m)
        x_a = self.backbone(x_a)
        last = torch.cat([x_m, x_a ],-1)
        last = self.mlp(last)
        cancer = self.cancer(last).reshape(-1)

        return torch.sigmoid(cancer)

In [10]:
batch = np.random.random(size=(1,3,1024,1024))
batch = torch.tensor( batch).to(device).to(torch.float32)
batch_size,C,H,W = batch.shape
x = batch.reshape(-1, C, H, W)

In [11]:
x.shape

torch.Size([1, 3, 1024, 1024])

In [12]:
model = basNet().to(device)

In [13]:
x = model(x)

  x_m =torch.tensor( np.array_split(batch,2,axis=3)[0])
  x_a = torch.tensor( np.array_split(batch,2,axis=3)[1])


In [14]:
x.shape

torch.Size([1])

In [12]:
m = Backbone()

In [28]:
batch = np.random.random(size=(8,3,256,256))
batch = torch.tensor( batch).to(device).to(torch.float32)
batch_size,C,H,W = batch.shape
x = batch.reshape(-1, C, H, W)

In [29]:
x.shape

torch.Size([8, 3, 256, 256])

In [30]:
m.to(device)
x = m(x)
x[0].shape

torch.Size([8, 1280, 8, 8])