# TMANet

В этом ноутбуке будет реализована сеть TMANet из этой статьи: https://arxiv.org/abs/2102.08643

## Датасет:

В качестве датасета возьмём CamVid: https://www.kaggle.com/datasets/carlolepelaars/camvid

In [1]:
#from google.colab import files
#files.upload()

#!mkdir -p ~/.kaggle
#!cp kaggle.json ~/.kaggle/
#!pip install kaggle
#!chmod 600 /root/.kaggle/kaggle.json
#!kaggle datasets download -d carlolepelaars/camvid
#!unzip camvid.zip

Посмотрим на значения классов: 

In [2]:
import pandas as pd 

# Load the class table shipped with the dataset and show it as the cell output.
class_dict_path = '../input/camvid/CamVid/class_dict.csv'
df = pd.read_csv(class_dict_path)
df

Распарсим все изображения:

In [3]:
import torch
from torchvision import transforms
from PIL import Image

def get_tensor_image_from_path(path, size=(64, 64), resample=None):
    """Load an image file, resize it and convert it to a tensor.

    Args:
        path: Location of the image file, passed to ``PIL.Image.open``.
        size: Target size handed to ``Image.resize``.  Defaults to
            ``(64, 64)``, the size used by every caller in this notebook.
        resample: Optional Pillow resampling filter.  When ``None`` the
            library default is used.  For label masks a nearest-neighbour
            filter avoids blending the colours of neighbouring classes.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor`` for the
        resized image.
    """
    # The context manager releases the file handle as soon as the resized
    # copy has been produced.
    with Image.open(path) as img:
        if resample is None:
            resized = img.resize(size)
        else:
            resized = img.resize(size, resample=resample)
    return transforms.ToTensor()(resized)


In [4]:
import os
# Group the training frames and their label masks by video.  The first six
# characters of a file name are used as the video key.
name_video2images = {}
name_video2labels = {}

# Directory visited by os.walk -> mapping that collects its file names.
split_targets = {
    '../input/camvid/CamVid/train': name_video2images,
    '../input/camvid/CamVid/train_labels': name_video2labels,
}

for root, _dirs, files in os.walk('../input/camvid/CamVid'):
    grouped = split_targets.get(root)
    if grouped is None:
        continue
    for file_name in files:
        grouped.setdefault(file_name[:6], []).append(file_name)

In [5]:
# Print the number of video keys found for frames and for label masks, then
# make sure both sides cover exactly the same videos.
for grouped in (name_video2images, name_video2labels):
    print(len(grouped))
assert set(name_video2images) == set(name_video2labels)

In [6]:
# Every video must have as many label masks as it has frames.
for cluster in set(name_video2images):
    n_frames = len(name_video2images[cluster])
    n_masks = len(name_video2labels[cluster])
    print(n_frames, n_masks)
    assert n_frames == n_masks

Отсортируем данные внутри каждого видео:

In [7]:
# Sort the file names of every video.  Frames and masks are later paired by
# list position, so both lists are ordered the same way.
for video_key in name_video2images:
    for grouped in (name_video2images, name_video2labels):
        grouped[video_key] = sorted(grouped[video_key])

Проверим, у всех ли изображений одинаковый размер:

In [8]:
# Check that every training image has the same size.
# get_tensor_image_from_path resizes each image before converting it, so
# comparing tensor shapes alone can never detect a mismatch; the size of the
# file as reported by PIL is therefore compared as well.
H, W = -1, -1
original_size = None
image_dir = '../input/camvid/CamVid/train/'
for video_frames in name_video2images.values():
    for path in video_frames:
        # Image.open reads the header lazily, so this size check is cheap.
        with Image.open(image_dir + path) as raw_image:
            if original_size is None:
                original_size = raw_image.size
            assert raw_image.size == original_size, (
                f"{path}: size {raw_image.size} differs from {original_size}"
            )
        cur_tensor = get_tensor_image_from_path(image_dir + path)
        if H == -1:
            # H and W describe the tensors fed to the models (after resizing).
            H = int(cur_tensor.shape[1])
            W = int(cur_tensor.shape[2])
        assert H == int(cur_tensor.shape[1])
        assert W == int(cur_tensor.shape[2])

Наконец, напишем класс для изображений:

In [9]:
from torch.utils.data import DataLoader, Dataset

class VideoDataset(Dataset):
    """Dataset whose items are whole videos.

    Item ``index`` corresponds to the ``index``-th key of ``image_dict`` and
    is a pair ``(image_tensors, label_tensors)``: one list with a tensor per
    file listed in ``image_dict`` for that key, and one list with a tensor
    per file listed in ``labels_dict`` for the same key.
    """

    def __init__(self, image_dict, labels_dict, image_dir, label_dir):
        # Mappings from a video key to the file names of its frames / masks.
        self.image_dict = image_dict
        self.labels_dict = labels_dict
        # Prefixes put in front of the file names by plain string
        # concatenation, so they are expected to end with a separator.
        self.image_dir = image_dir
        self.label_dir = label_dir

    def __len__(self):
        """Return the number of videos."""
        return len(self.image_dict)

    def __getitem__(self, index):
        """Load every frame and every label mask of the selected video."""
        segment = list(self.image_dict)[index]
        image_tensors = [
            get_tensor_image_from_path(self.image_dir + file_name)
            for file_name in self.image_dict[segment]
        ]
        label_tensors = [
            get_tensor_image_from_path(self.label_dir + file_name)
            for file_name in self.labels_dict[segment]
        ]
        return image_tensors, label_tensors

In [10]:
# Training split: frames and label masks grouped per video.
train_dataset = VideoDataset(
    image_dict=name_video2images,
    labels_dict=name_video2labels,
    image_dir='../input/camvid/CamVid/train/',
    label_dir='../input/camvid/CamVid/train_labels/',
)

Сделаем также датасет для валидации и теста:

In [11]:
# Group the validation frames and label masks by video (key = first six
# characters of the file name), then sort the file names of every video.
name_video2images_valid = {}
name_video2labels_valid = {}

for root, _dirs, files in os.walk('../input/camvid/CamVid'):
    if root == '../input/camvid/CamVid/val':
        grouped = name_video2images_valid
    elif root == '../input/camvid/CamVid/val_labels':
        grouped = name_video2labels_valid
    else:
        continue
    for file_name in files:
        grouped.setdefault(file_name[:6], []).append(file_name)

for video_key in name_video2images_valid:
    name_video2images_valid[video_key] = sorted(name_video2images_valid[video_key])
    name_video2labels_valid[video_key] = sorted(name_video2labels_valid[video_key])

In [12]:
# Validation split: frames and label masks grouped per video.
valid_dataset = VideoDataset(
    image_dict=name_video2images_valid,
    labels_dict=name_video2labels_valid,
    image_dir='../input/camvid/CamVid/val/',
    label_dir='../input/camvid/CamVid/val_labels/',
)

Датасет готов. Попробуем какую-нибудь простую модель, чтобы было с чем сравнивать результаты, например UNet:

In [13]:
import torch
import torch.nn as nn 

class Block(nn.Module):
    """Convolution -> batch normalisation -> ReLU unit.

    The 3x3 convolution uses stride 1 and padding 1, so the spatial size of
    the input is preserved and only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        """Apply the three layers to ``X`` in order."""
        return self.model(X)


In [14]:
class UNet(nn.Module):
    """Small encoder-decoder baseline built from ``Block`` units.

    Both pooling layers use a 3x3 window with stride 1 and padding 1, so the
    spatial resolution is never reduced; their indices are reused by the
    matching ``MaxUnpool2d`` layers.  The decoder output (3 channels) is
    concatenated with the network input before the final convolutions, and
    the result is squashed into (0, 1) by a sigmoid.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels = 3):
        super().__init__()
        # Encoder.
        self.conv1 = nn.Sequential(
            Block(in_channels, 16),
            Block(16, 16),
            Block(16, 16),
        )
        self.pool1 = nn.MaxPool2d(3, 1, 1, return_indices=True)
        self.conv2 = nn.Sequential(
            Block(16, 8),
            Block(8, 8),
            Block(8, 8),
        )
        self.pool2 = nn.MaxPool2d(3, 1, 1, return_indices=True)
        self.conv3 = Block(8, 8)

        # Decoder.
        self.unpool1 = nn.MaxUnpool2d(3, 1, 1)
        self.up1 = nn.Sequential(
            Block(8, 8),
            Block(8, 16),
        )
        self.unpool2 = nn.MaxUnpool2d(3, 1, 1)
        self.up2 = nn.Sequential(
            Block(16, 16),
            Block(16, 3),
        )

        # The head receives the 3 decoder channels concatenated with the raw
        # input, hence 3 + in_channels (6 for the default RGB input).  The
        # previous hard-coded 6 only worked for in_channels == 3.
        self.last = nn.Sequential(
            Block(3 + in_channels, 9),
            Block(9, 6),
            nn.Conv2d(6, 3, 3, 1, 1),
        )
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def forward(self, X):
        """Return a 3-channel map with values in (0, 1) for the batch ``X``."""
        feats = self.conv1(X)
        feats, ind1 = self.pool1(feats)
        feats = self.dropout(self.conv2(feats))
        feats, ind2 = self.pool2(feats)
        feats = self.dropout(self.conv3(feats))
        # Undo the pooling steps in reverse order with the stored indices.
        feats = self.unpool1(feats, ind2)
        feats = self.dropout(self.up1(feats))
        feats = self.unpool2(feats, ind1)
        feats = self.up2(feats)
        # Skip connection from the input image.
        feats = torch.cat([feats, X], 1)
        return self.sigmoid(self.last(feats))

Обучим UNet и посмотрим результаты:

In [15]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

In [16]:
# Show the tensor height and width recorded by the image-size check above.
H, W

In [17]:
# Baseline network on the selected device, optimised with Adam against a
# mean-squared-error loss.
model_baseline = UNet().to(device)
optimizer = torch.optim.Adam(params=model_baseline.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [18]:
from tqdm.auto import tqdm
num_epochs = 40

for epoch in tqdm(range(num_epochs)):
    # ---- training pass: one optimisation step per frame ----
    model_baseline.train()
    sum_loss, cnt_loss = 0, 0
    for list_videos, list_labels in train_dataset:
        for frame, target in zip(list_videos, list_labels):
            optimizer.zero_grad()
            image = frame.unsqueeze(0).to(device)
            label = target.unsqueeze(0).to(device)
            output = model_baseline(image)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            cnt_loss += 1

    # ---- validation pass ----
    # Switch to evaluation mode: otherwise Dropout stays active and
    # BatchNorm normalises with (and updates its running statistics from)
    # the validation frames, which distorts the reported loss.
    model_baseline.eval()
    sum_val, cnt_val = 0, 0
    with torch.no_grad():
        for list_videos, list_labels in valid_dataset:
            for frame, target in zip(list_videos, list_labels):
                image = frame.unsqueeze(0).to(device)
                label = target.unsqueeze(0).to(device)
                output = model_baseline(image)
                loss = criterion(output, label)
                sum_val += loss.item()
                cnt_val += 1
    # max(..., 1) keeps the report from dividing by zero on an empty split.
    print(f"Mean train loss: {sum_loss / max(cnt_loss, 1)} | Mean valid loss: {sum_val / max(cnt_val, 1)}")


Наконец, реализуем саму модель:

In [19]:
class TMANet(nn.Module):
    """Simplified temporal-memory-attention segmentation network.

    The module stores the most recent input frames in ``previous_images``
    and, on every call, attends from the current frame to all stored frames
    (the current frame included, because it is appended before the memory is
    read).  Frames have to be passed one at a time (batch size 1), since the
    attention result is reshaped to a single ``(1, cv, H, W)`` map.  Call
    ``clear_history`` before feeding frames of a different video.

    Args:
        lim_last: Maximum number of frames kept in the memory (at least 1).

    Raises:
        ValueError: If ``lim_last`` is smaller than 1.
    """

    def __init__(self, lim_last=6):
        super().__init__()
        if lim_last < 1:
            raise ValueError("lim_last must be at least 1")
        self.backbone = nn.Sequential(
            Block(3, 4),
            Block(4, 8),
            Block(8, 10)
        )
        # The "blue"/"green" tags are kept from the original code; they
        # presumably refer to the colours used in the paper's figure.
        self.convA = nn.Sequential( # blue
            Block(10, 10),
            Block(10, 10)
        )
        self.convB = nn.Sequential( # green
            Block(10, 10),
            Block(10, 10)
        )
        self.previous_images = []
        # The attention matrix is 2-D (query pixels x memory pixels); dim=1
        # normalises over the memory positions.  This is the dimension the
        # deprecated implicit choice of nn.Softmax() selects for 2-D input.
        self.softmax = nn.Softmax(dim=1)
        self.prelast_net = nn.Sequential(
            Block(20, 16),
            Block(16, 8),
            Block(8, 3)
        )
        self.last_net = nn.Sequential(
            Block(6, 4),
            Block(4, 3)
        )
        self.sigmoid = nn.Sigmoid()
        self.lim_last = lim_last

    def forward(self, X):
        """Segment the frame ``X`` using the stored frames as memory.

        Args:
            X: Current frame, a tensor of shape ``(1, 3, H, W)``.

        Returns:
            A ``(1, 3, H, W)`` tensor with values in (0, 1).
        """
        # Update the memory, dropping the oldest frame once it is full.
        self.previous_images.append(X)
        if len(self.previous_images) > self.lim_last:
            self.previous_images.pop(0)
        memory_sequence = torch.cat(self.previous_images, 0)

        # Memory branch.
        backbone_output = self.backbone(memory_sequence)
        hidden = self.convA(backbone_output)  # T x cv x H x W
        T = int(hidden.shape[0])
        cv = int(hidden.shape[1])
        H = int(hidden.shape[2])
        W = int(hidden.shape[3])
        MV = hidden.permute(0, 2, 3, 1).reshape(T * H * W, -1)  # (T*H*W) x cv
        hidden_1 = self.convB(backbone_output)  # T x ck x H x W
        ck = int(hidden_1.shape[1])
        MK = hidden_1.permute(1, 0, 2, 3).reshape(ck, -1)  # ck x (T*H*W)

        # Query branch (current frame only).
        backbone_cur_output = self.backbone(X)
        hidden_cur = self.convA(backbone_cur_output)
        QK = hidden_cur.permute(0, 2, 3, 1).reshape(H * W, -1)  # (H*W) x c
        hidden_cur_1 = self.convB(backbone_cur_output)
        QV = hidden_cur_1

        # Attention of every query pixel over all memory pixels.
        S = self.softmax(QK @ MK)  # (H*W) x (T*H*W)
        QK_add = S @ MV  # (H*W) x cv
        QK_add = QK_add.permute(1, 0).reshape(1, cv, H, W)

        # Fuse the attended memory with the current features, then add a
        # skip connection from the input frame before the final layers.
        QV = torch.cat((QV, QK_add), 1)
        output = self.prelast_net(QV)
        output = torch.cat([output, X], 1)
        output = self.last_net(output)
        output = self.sigmoid(output)
        return output

    def clear_history(self):
        """Forget all stored frames so the next call starts with an empty memory.

        The previous implementation only unbound its loop variable
        (``del elem``) and left the list untouched, so frames of one video
        leaked into the memory used for the next one.
        """
        self.previous_images.clear()
In [20]:
# TMANet on the selected device, optimised with Adam against a
# mean-squared-error loss (same settings as the baseline).
model = TMANet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
from tqdm.auto import tqdm
num_epochs = 30

for epoch in tqdm(range(num_epochs)):
    # ---- training pass: frames are fed one by one, in video order ----
    model.train()
    sum_loss, cnt_loss = 0, 0
    for list_videos, list_labels in train_dataset:
        for frame, target in zip(list_videos, list_labels):
            optimizer.zero_grad()
            image = frame.unsqueeze(0).to(device)
            label = target.unsqueeze(0).to(device)
            output = model(image)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            cnt_loss += 1
        # Ask the model to drop its frame memory before the next video.
        model.clear_history()

    # ---- validation pass ----
    # Switch to evaluation mode so that BatchNorm uses its running
    # statistics instead of normalising with (and learning from) the
    # validation frames.
    model.eval()
    sum_val, cnt_val = 0, 0
    with torch.no_grad():
        for list_videos, list_labels in valid_dataset:
            for frame, target in zip(list_videos, list_labels):
                image = frame.unsqueeze(0).to(device)
                label = target.unsqueeze(0).to(device)
                output = model(image)
                loss = criterion(output, label)
                sum_val += loss.item()
                cnt_val += 1
            model.clear_history()
    # max(..., 1) keeps the report from dividing by zero on an empty split.
    print(f"Mean train loss: {sum_loss / max(cnt_loss, 1)} | Mean valid loss: {sum_val / max(cnt_val, 1)}")


Видно, что TMANet уступает обычному UNet; вероятно, его нужно обучать дольше.