In [1]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.models import efficientnet_v2_l
from test import EffnetV2_Key_Frame
import pandas as pd
import numpy as np
import torch.utils.data as data
import albumentations as A
import torchvision.transforms as transforms
from PIL import Image
import torchvision
import joblib
import cv2
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s, efficientnet_v2_m
#from .cbam import CBAMBlock


In [2]:
class SpatialAttention(torch.nn.Module):
    def __init__(self, feature_map_size = 16, n_channels = 1280, use_layer_norm = False, use_alpha = True, use_skip_connection = True, use_gelu = False):
        super().__init__()
    
        self.use_alpha = use_alpha
        self.use_skip_connection = use_skip_connection
        self.use_gelu = use_gelu
        self.use_layer_norm = use_layer_norm
        self.n_channels = n_channels
        self.feature_map_size = feature_map_size
        self.keys = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.queries = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.values = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.refine = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.softmax = torch.nn.Softmax2d()
        self.gelu = torch.nn.GELU()
        if self.use_alpha:
            self.alpha = torch.nn.Parameter(torch.zeros(1))
        if self.use_layer_norm:
            self.layer_norm = torch.nn.LayerNorm([self.n_channels, self.feature_map_size, self.feature_map_size])
    def forward(self, x):
     #   print('x in spatial attention', x.shape)
        attended_features = torch.matmul(self.softmax(torch.matmul(self.keys(x).view(x.size(0), self.n_channels, -1).permute(0, 2, 1), 
                                                                   self.queries(x).view(x.size(0), self.n_channels, -1))/self.n_channels**0.5), 
                                         self.values(x).view(x.size(0), self.n_channels, -1).permute(0, 2, 1)) # (batch_size, feature_map_size * feature_map_size, n_channels)
        attended_features = attended_features.permute(0, 2, 1).view(x.size(0), self.n_channels, self.feature_map_size, self.feature_map_size) # (batch_size, n_channels, feature_map_size, feature_map_size)
      #  print('attended_features', attended_features.shape)
        attended_features = self.refine(attended_features)
        if self.use_alpha:
            #print('spatial using alpha')
            attended_features = self.alpha*attended_features + x
        else:
            #print('spatial not using alpha')
            attended_features = attended_features + x
        if self.use_layer_norm:
            #print('spatial using layer norm')
            attended_features = self.layer_norm(attended_features)
        if self.use_gelu:
            #print('spatial using gelu')
            attended_features = self.gelu(attended_features)
        return attended_features

In [3]:
class EffnetV2_L(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1, dropout = 0.4, use_sigmoid = False, use_attention = False):
        super().__init__()
        
        self.use_sigmoid = use_sigmoid
        self.use_attention = use_attention
        self.dropout = dropout
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_s(weights = 'EfficientNet_V2_S_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.model.avgpool = torch.nn.Identity()
        self.model.classifier = torch.nn.Sequential(nn.Dropout(self.dropout), nn.Linear(1280, self.out_features))
        self.sigmoid = torch.nn.Sigmoid()
        if self.use_attention:
            self.spatial_attention = SpatialAttention(feature_map_size = 16, n_channels=1280)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x):
        if self.use_attention:
            x = self.model.features(x)
            #x = self.model.conv(x)
            x = self.spatial_attention(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        else:
            x = self.model.features(x)
            #x = self.model.conv(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        if self.use_sigmoid:
            x = self.sigmoid(x)
        return x

In [4]:
class Fetal_frame_eval_cls(data.Dataset):
    def __init__(self, root, ann_path, transform=None, target_transform=None):

        self.data_path = root
        self.ann_path = ann_path
        self.transform = transform
        self.target_transform = target_transform
        self.database = pd.read_csv(self.ann_path)

    def _load_image(self, path):
        try:
            im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        except:
            print("ERROR IMG NOT LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        idb = self.database.iloc[index]
        frame_idx = idb[0]
        video = idb[2]
        ps = idb[4]
        Class = idb[1]

        images = self._load_image(self.data_path  + frame_idx + '.png')
        images = np.expand_dims(images, 2)
        t = transforms.Compose([transforms.ToTensor(),
        transforms.Resize((450, 600)),
        transforms.Pad((0, 0, 0, 150), fill = 0, padding_mode = 'constant'),
        transforms.Resize((512, 512)),
        transforms.Normalize(mean=0.1354949, std=0.18222201)])
        images = t(images)
        if self.transform is not None:
            images = self.transform(images)
        
        return images, frame_idx, video, ps, Class

    def __len__(self):
        return len(self.database)
    
    def load_video(self, video):
        indexes = self.database.query('video == @video')['index']
        images = list()
        Classes = list()
        measures = list()
        pss = list()
        frames_n = list()
        measures_normalized = list()
        indexes = list()
        for i in indexes:
            image, Class, measure, ps, frame_n, measure_normalized, index = self.__getitem__(i)
            images.append(image)
            Classes.append(Class)
            measures.append(measure)
            pss.append(ps)
            frames_n.append(frame_n)
            measures_normalized.append(measure_normalized)
            indexes.append(index)
        return torch.stack(images), torch.stack(Classes), torch.stack(measures), torch.stack(pss), torch.stack(frames_n), torch.stack(measures_normalized), torch.stack(indexes)
    
    def load_batch(self, index_list):
        images = list()
        Classes = list()
        measures = list()
        pss = list()
        frames_n = list()
        measures_normalized = list()
        indexes = list()
        for i in index_list:
            image, Class, measure, ps, frame_n, measure_normalized, index = self.__getitem__(i)
            images.append(image)
            Classes.append(Class)
            measures.append(measure)
            pss.append(ps)
            frames_n.append(frame_n)
            measures_normalized.append(measure_normalized)
            indexes.append(index)
        return torch.stack(images), torch.stack(Classes), torch.stack(measures), torch.stack(pss), torch.stack(frames_n), torch.stack(measures_normalized), torch.stack(indexes)
                      

In [5]:
ann_path = '/data/kpusteln/Fetal-RL/data_preparation/data_biometry/ete_model/biometry_scaled_ps/class_data_split/val_head.csv'
root = '/data/kpusteln/fetal/fetal_extracted/'

In [6]:
dataset = Fetal_frame_eval_cls(root, ann_path)
data_loader = torch.utils.data.DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = 0)

In [7]:
for image, frame_idx, video, ps, Class in data_loader:
    print(image.shape, frame_idx, video, ps, Class)
    break

torch.Size([1, 1, 512, 512]) ('2_1_1',) ('2_1',) tensor([0.1079], dtype=torch.float64) tensor([0])


In [8]:
model = EffnetV2_L(2, 1, use_attention = True, use_sigmoid = False)
checkpoint = torch.load('/data/kpusteln/Fetal-RL/src/output/effnetv2_bayes_searchbacbkone__body_parthead_bs16_lr1.15466336e-05_drop0.15_n_frames3_key_frame_attFalse_alphaTrue_use_skip_connectionTrue_use_geluTrue_use_batch_normTrue_use_headFalse/default/ckpt_epoch_10.pth', map_location = 'cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

EffnetV2_L(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): Sequential(
        (0): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): Stochasti

In [9]:
results = torch.tensor([])
frames = []
videos = []
measures = []
Classes = torch.tensor([])
probs = torch.tensor([])
pss = torch.tensor([])

In [10]:
output = model(image)
prob = output.softmax(dim = 1).max(dim = 1)[0]
result = output.softmax(dim = 1).max(dim = 1)[1].to(torch.int64)
videos.append(video[0])
probs = torch.cat((probs, prob))
results = torch.cat((results, result))
pss = torch.cat((pss, ps))
frames.append(frame_idx[0])
Classes = torch.cat((Classes, Class))
#Classes = torch.stack(Classes)
#Classes = Classes.numpy()

In [11]:
probs

tensor([0.8689], grad_fn=<CatBackward0>)

In [21]:
# read img
img = cv2.imread('/data/kpusteln/fetal/fetal_extracted/60_3_87.png', cv2.IMREAD_GRAYSCALE)

In [22]:
t = transforms.Compose([transforms.ToTensor(),
transforms.Resize((450, 600)),
transforms.Pad((0, 0, 0, 150), fill = 0, padding_mode = 'constant'),
transforms.Resize((512, 512)),
transforms.Normalize(mean=0.1354949, std=0.18222201)])

In [23]:
img = t(img).unsqueeze(0)

In [25]:
output = model(img)

In [30]:
output

tensor([[ 2.7256, -2.7432]], grad_fn=<AddmmBackward0>)

In [31]:
prob = output.softmax(dim = 1)

In [32]:
prob

tensor([[0.9958, 0.0042]], grad_fn=<SoftmaxBackward0>)

In [25]:
probs = torch.tensor([])

In [29]:
probs = torch.cat((probs, prob))

In [30]:
probs

tensor([0.8689, 0.8689], grad_fn=<CatBackward0>)

In [99]:
tensor1 = torch.rand((16, 4, 1280))
tensor2 = torch.rand((16, 4, 1280))


In [100]:
tensor1 = tensor1.permute(0, 2, 1)

In [101]:
tensor1.shape

torch.Size([16, 1280, 4])

In [102]:
tensor2.shape

torch.Size([16, 4, 1280])

In [103]:
att = torch.bmm(tensor2, tensor1)/torch.sqrt(torch.tensor(1280.0))

In [119]:
torch.softmax(att, dim = 1)

tensor([[[0.2463, 0.2650, 0.2361, 0.2188],
         [0.2427, 0.2209, 0.2358, 0.2575],
         [0.2818, 0.2735, 0.2967, 0.3114],
         [0.2291, 0.2406, 0.2314, 0.2123]],

        [[0.2794, 0.2888, 0.2456, 0.2823],
         [0.2311, 0.1973, 0.2233, 0.1805],
         [0.2743, 0.2631, 0.3127, 0.3267],
         [0.2152, 0.2508, 0.2185, 0.2104]],

        [[0.2993, 0.2886, 0.2563, 0.2999],
         [0.2840, 0.2811, 0.2994, 0.2826],
         [0.1726, 0.1746, 0.1953, 0.1671],
         [0.2442, 0.2557, 0.2490, 0.2504]],

        [[0.2529, 0.2530, 0.2168, 0.2142],
         [0.2559, 0.2382, 0.2111, 0.2441],
         [0.2288, 0.2317, 0.2629, 0.2994],
         [0.2624, 0.2772, 0.3093, 0.2423]],

        [[0.2327, 0.2092, 0.2440, 0.2629],
         [0.2743, 0.2872, 0.2605, 0.2544],
         [0.2235, 0.2511, 0.2422, 0.2015],
         [0.2695, 0.2525, 0.2533, 0.2812]],

        [[0.2044, 0.2203, 0.2463, 0.2121],
         [0.2514, 0.2571, 0.2221, 0.2142],
         [0.2430, 0.2799, 0.2575, 0.2540],
 