<a href="https://colab.research.google.com/github/haminhtien99/re3-pytorch/blob/master/re3-pytorch-colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Xây dựng mạng từ ban đầu. Cấu trúc mạng làm theo tác giả re3-pytorch

### Library

In [None]:
!pip install lime

Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/275.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=3e52f05c621bc42f397bf489379b165276e46ee74392309089314ad65f5b6b72
  Stored in directory: /root/.cache/pip/wheels/fd/a2/af/9ac0a1a85a27f314a06b39e1f492bee1547d52549a4606ed89
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


In [None]:
from copy import deepcopy

import matplotlib.pyplot as plt
from matplotlib.image import imread
from mpl_toolkits import mplot3d
from matplotlib import gridspec
from PIL import Image
import io
from urllib.request import urlopen
from lime import lime_image
from skimage.segmentation import mark_boundaries

from tqdm.notebook import tqdm
import numpy as np
import requests
import torch

from sklearn.metrics import classification_report
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets, transforms

import os.path
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only; prompts for authorisation on first run).
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change the working directory to the re3-pytorch checkout on the mounted Drive,
# presumably so that the `re3_utils` package imported below resolves.
# NOTE(review): this is a hard-coded, user-specific path — adjust it for your own Drive layout.
%cd /content/drive/MyDrive/Re3-Object-Tracking/re3-pytorch
from re3_utils.pytorch_util import pytorch_util_functions as pt_util
from re3_utils.pytorch_util.CaffeLSTMCell import CaffeLSTMCell

/content/drive/MyDrive/Re3-Object-Tracking/re3-pytorch


In [None]:
import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings; consider
# narrowing the filter to specific categories or modules.
warnings.filterwarnings("ignore")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # bare expression so the notebook displays the selected device

device(type='cpu')

### Train network

In [None]:
def train_on_batch(model, x_batch, y_batch, optimizer, loss_function):
    """Run a single optimisation step on one mini-batch.

    The batch is moved to the device that holds the model's parameters, so the
    step works regardless of where the model lives. Only a model without any
    parameters falls back to the notebook-level ``device``.

    Args:
        model: ``nn.Module`` to train; it is switched to training mode.
        x_batch: input tensor, passed to ``model`` after the device transfer.
        y_batch: target tensor, passed to ``loss_function`` after the device transfer.
        optimizer: optimizer whose ``step()`` is called after the backward pass.
        loss_function: callable ``(output, target) -> scalar loss tensor``.

    Returns:
        float: the loss value of this batch.
    """
    try:
        # Follow the model instead of a global, avoiding device mismatches.
        target_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: keep the previous behaviour (global device).
        target_device = device

    model.train()
    model.zero_grad()

    output = model(x_batch.to(target_device))

    loss = loss_function(output, y_batch.to(target_device))
    loss.backward()

    optimizer.step()
    return loss.item()

In [None]:
def train_epoch(train_generator, model, loss_function, optimizer, callback = None):
    """Train ``model`` for one pass over ``train_generator``.

    Args:
        train_generator: iterable yielding ``(batch_of_x, batch_of_y)`` pairs.
        model: module handed to ``train_on_batch`` for every batch.
        loss_function: loss callable forwarded to ``train_on_batch``.
        optimizer: optimizer forwarded to ``train_on_batch``.
        callback: optional callable invoked as ``callback(model, batch_loss)``
            after every batch.

    Returns:
        The batch losses averaged with each batch weighted by its length, or
        ``nan`` when the generator produced no samples (instead of raising
        ``ZeroDivisionError``).
    """
    weighted_loss_sum = 0
    sample_count = 0
    for batch_of_x, batch_of_y in train_generator:
        # Device placement is handled inside train_on_batch, so no transfer here.
        batch_loss = train_on_batch(model, batch_of_x, batch_of_y, optimizer, loss_function)

        if callback is not None:
            callback(model, batch_loss)

        batch_len = len(batch_of_x)
        weighted_loss_sum += batch_loss * batch_len
        sample_count += batch_len

    if sample_count == 0:
        # Same placeholder value trainer() shows before the first epoch finishes.
        return float("nan")
    return weighted_loss_sum / sample_count

In [None]:
def trainer(count_of_epoch,
            batch_size,
            dataset,
            model,
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):
    """Train ``model`` on ``dataset`` for ``count_of_epoch`` epochs.

    Args:
        count_of_epoch: number of passes over the dataset.
        batch_size: mini-batch size given to the ``DataLoader``.
        dataset: dataset wrapped in a shuffling ``DataLoader`` each epoch.
        model: module whose parameters are optimised.
        loss_function: loss callable forwarded to ``train_epoch``.
        optimizer: optimizer *class/factory*, called as
            ``optimizer(model.parameters(), lr=lr)``.
        lr: learning rate passed to the optimizer factory.
        callback: optional per-batch callback forwarded to ``train_epoch``.

    Returns:
        None. Progress and the latest epoch loss are shown on tqdm bars.
    """
    optima = optimizer(model.parameters(), lr=lr)

    epoch_bar = tqdm(range(count_of_epoch), desc='epoch')
    epoch_bar.set_postfix({'train epoch loss': np.nan})
    for _ in epoch_bar:
        # A fresh shuffling loader is created for every epoch.
        loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
        # Number of batches, counting a trailing partial batch.
        sample_count = len(dataset)
        batch_count = sample_count // batch_size + (sample_count % batch_size > 0)
        batch_bar = tqdm(loader, leave=False, total=batch_count)

        mean_loss = train_epoch(
            train_generator=batch_bar,
            model=model,
            loss_function=loss_function,
            optimizer=optima,
            callback=callback,
        )

        epoch_bar.set_postfix({'train epoch loss': mean_loss})

### Network Structure

In [None]:
class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv2d -> GroupNorm -> (optional) ELU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels; also the channel count
            normalised by the GroupNorm layer.
        padding, kernel_size, stride: forwarded to ``nn.Conv2d``.
        with_nonlinearity: when False the ELU is skipped in ``forward``.
        num_groups: number of groups for ``nn.GroupNorm`` (default 32, the
            value that used to be hard-coded). ``nn.GroupNorm`` requires
            ``out_channels`` to be divisible by ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True,
                 num_groups=32):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        # Despite the name `bn` this is a GroupNorm; the attribute name is kept
        # so that state_dict keys stay unchanged.
        self.bn = nn.GroupNorm(num_groups, out_channels)
        self.nonlinearity = nn.ELU(inplace=True)
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        """Apply convolution, group normalisation and, if enabled, ELU."""
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.nonlinearity(x)
        return x

In [None]:
class Re3SmallNet(nn.Module):
    """Small tracking network: conv feature extractor -> fc6 -> two LSTM cells -> 4 outputs.

    The most recent LSTM state ``(h1, c1, h2, c2)`` is kept in
    ``self.lstm_state`` after every forward pass. ``args`` is accepted but
    not used by the constructor.
    """

    def __init__(self, lstm_size=512, args=None):
        super(Re3SmallNet, self).__init__()
        self.lstm_size = lstm_size

        # Stem with stride 4, then 3x3 ConvBlocks separated by 2x2 max-pooling.
        layers = [ConvBlock(in_channels=3, out_channels=32, padding=3, kernel_size=7, stride=4)]
        for stage, (c_in, c_out) in enumerate([(32, 64), (64, 128), (128, 256), (256, 512)]):
            if stage > 0:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers.append(ConvBlock(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))
        self.feature_extractor = nn.Sequential(*layers)

        def _ensure_4d(x):
            # Non-4-D inputs go through pt_util.remove_dim(x, 1).
            return x if len(x.shape) == 4 else pt_util.remove_dim(x, 1)

        def _to_float32(x):
            return x.to(torch.float32)

        def _normalize(x):
            # Per-channel mean/std shaped (1, 1, 1, 3) — assumes channels-last input, TODO confirm.
            return pt_util.normalize(
                x,
                mean=np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, np.newaxis, :],
                std=np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, np.newaxis, :],
            )

        def _channels_to_dim1(x):
            return x.permute(0, 3, 1, 2)

        self.transform = transforms.Compose(
            [transforms.Lambda(step) for step in (_ensure_4d, _to_float32, _normalize, _channels_to_dim1)]
        )

        self.fc6 = nn.Sequential(
            nn.Linear(50176, 2048),
            nn.ELU(),
        )
        self.lstm1 = nn.LSTMCell(2048, self.lstm_size)
        # The second cell sees fc6 concatenated with the first cell's output.
        self.lstm2 = nn.LSTMCell(2048 + self.lstm_size, self.lstm_size)
        self.fc_output = nn.Sequential(
            nn.Linear(self.lstm_size, self.lstm_size),
            nn.ELU(inplace=True),
            nn.Linear(self.lstm_size, 4),
        )
        self.learning_rate = None
        self.optimizer = None
        self.outputs = None
        self.lstm_state = None

    @property
    def device(self):
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    def forward(self, input, lstm_state=None):
        """Compute the 4-value output for ``input``.

        Args:
            input: image tensor; moved to the module's device as float32.
            lstm_state: optional ``(h1, c1, h2, c2)`` tuple from a previous
                step; when None both LSTM cells start from their default state.

        Returns:
            Output of ``fc_output`` applied to the second LSTM cell's output.
        """
        frames = input.to(self.device, dtype=torch.float32)
        features = self.feature_extractor(self.transform(frames))
        # Regroup axis 0 into chunks of 2 and flatten the trailing dims
        # (presumably pairing the two crops of each sample — verify against pt_util).
        features = pt_util.split_axis(features, 0, -1, 2)
        features = pt_util.remove_dim(features, (2, 3, 4))

        fc6 = self.fc6(features)

        if lstm_state is None:
            prev1 = prev2 = None
        else:
            h1, c1, h2, c2 = lstm_state
            prev1, prev2 = (h1, c1), (h2, c2)

        h1, c1 = self.lstm1(fc6, prev1)
        h2, c2 = self.lstm2(torch.cat((fc6, h1), 1), prev2)

        self.lstm_state = (h1, c1, h2, c2)

        return self.fc_output(h2)


In [None]:
class Re3Net(nn.Module):
    """Tracking network: five conv layers with skip branches -> fc6 -> two CaffeLSTMCells -> 4 outputs.

    The most recent LSTM state ``(outputs1, state1, outputs2, state2)`` is
    kept in ``self.lstm_state`` after every forward pass. ``args`` is
    accepted but not used by the constructor.
    """

    def __init__(self, lstm_size=1024, args=None):
        super(Re3Net, self).__init__()
        self.lstm_size = lstm_size
        self.conv = nn.ModuleList(
            [
                nn.Conv2d(3, 96, 11, stride=4, padding=0),
                nn.Conv2d(96, 256, 5, padding=2, groups=2),
                nn.Conv2d(256, 384, 3, padding=1),
                nn.Conv2d(384, 384, 3, padding=1, groups=2),
                nn.Conv2d(384, 256, 3, padding=1, groups=2),
            ]
        )
        self.lrn = nn.ModuleList(
            [
                nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
                nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
            ]
        )

        # 1x1 convolutions (each followed by a PReLU) on the skip branches.
        self.conv_skip = nn.ModuleList([nn.Conv2d(96, 16, 1), nn.Conv2d(256, 32, 1), nn.Conv2d(256, 64, 1), ])
        self.prelu_skip = nn.ModuleList([nn.PReLU(16), nn.PReLU(32), nn.PReLU(64)])

        # 74208 is the length of the concatenated feature vector fed to fc6
        # (presumably tied to a fixed input crop size — TODO confirm).
        self.fc6 = nn.Linear(74208, 2048)

        self.lstm1 = CaffeLSTMCell(2048, self.lstm_size)
        # The second cell sees fc6 concatenated with the first cell's output.
        self.lstm2 = CaffeLSTMCell(2048 + self.lstm_size, self.lstm_size)

        self.lstm_state = None

        self.fc_output_out = nn.Linear(self.lstm_size, 4)

        self.transform = transforms.Compose(
            [
                transforms.Lambda(lambda x: x if len(x.shape) == 4 else pt_util.remove_dim(x, 1)),
                transforms.Lambda(lambda x: x.to(torch.float32)),
                # Only a per-channel mean is passed to normalize here (no std).
                transforms.Lambda(
                    lambda x: pt_util.normalize(
                        x,
                        mean=np.array([123.151630838, 115.902882574, 103.062623801], dtype=np.float32)[
                             np.newaxis, np.newaxis, np.newaxis, :
                             ],
                    )
                ),
                transforms.Lambda(lambda x: x.permute(0, 3, 1, 2)),
            ]
        )

    @property
    def device(self):
        """Device of the module's first parameter.

        ``forward`` reads ``self.device``; without this property that access
        raises ``AttributeError`` because ``nn.Module`` has no such attribute.
        """
        return next(self.parameters()).device

    def forward(self, input, lstm_state=None):
        """Compute the 4-value output for ``input``.

        Args:
            input: image tensor; transformed, then moved to the module's device.
            lstm_state: optional ``(outputs1, state1, outputs2, state2)`` tuple
                from a previous step; when None both LSTM cells are called
                without an initial state.

        Returns:
            Output of ``fc_output_out`` applied to the second LSTM cell's output.
        """
        input = self.transform(input).to(device=self.device)
        conv1 = self.conv[0](input)
        pool1 = F.relu(F.max_pool2d(conv1, (3, 3), stride=2))
        lrn1 = self.lrn[0](pool1)

        conv1_skip = self.prelu_skip[0](self.conv_skip[0](lrn1))
        conv1_skip_flat = pt_util.remove_dim(conv1_skip, [2, 3])

        conv2 = self.conv[1](lrn1)
        pool2 = F.relu(F.max_pool2d(conv2, (3, 3), stride=2))
        lrn2 = self.lrn[1](pool2)

        conv2_skip = self.prelu_skip[1](self.conv_skip[1](lrn2))
        conv2_skip_flat = pt_util.remove_dim(conv2_skip, [2, 3])

        conv3 = F.relu(self.conv[2](lrn2))
        conv4 = F.relu(self.conv[3](conv3))
        conv5 = F.relu(self.conv[4](conv4))
        pool5 = F.relu(F.max_pool2d(conv5, (3, 3), stride=2))
        pool5_flat = pt_util.remove_dim(pool5, [2, 3])

        conv5_skip = self.prelu_skip[2](self.conv_skip[2](conv5))
        conv5_skip_flat = pt_util.remove_dim(conv5_skip, [2, 3])

        # Concatenate the three skip branches with the pooled conv5 features.
        skip_concat = torch.cat([conv1_skip_flat, conv2_skip_flat, conv5_skip_flat, pool5_flat], 1)
        # Regroup axis 0 into chunks of 2 and merge them into one feature vector
        # (presumably pairing the two crops of each sample — verify against pt_util).
        skip_concat = pt_util.split_axis(skip_concat, 0, -1, 2)
        reshaped = pt_util.remove_dim(skip_concat, 2)

        fc6 = F.relu(self.fc6(reshaped))

        if lstm_state is None:
            outputs1, state1 = self.lstm1(fc6)
            outputs2, state2 = self.lstm2(torch.cat((fc6, outputs1), 1))
        else:
            outputs1, state1, outputs2, state2 = lstm_state
            outputs1, state1 = self.lstm1(fc6, (outputs1, state1))
            outputs2, state2 = self.lstm2(torch.cat((fc6, outputs1), 1), (outputs2, state2))

        self.lstm_state = (outputs1, state1, outputs2, state2)

        fc_output_out = self.fc_output_out(outputs2)
        return fc_output_out

In [None]:
# Build the small network with its default arguments and switch it to evaluation mode.
# model.eval() is the cell's last expression, so the notebook prints the layer summary.
model= Re3SmallNet()
model.eval()

Re3SmallNet(
  (feature_extractor): Sequential(
    (0): ConvBlock(
      (conv): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (bn): GroupNorm(32, 32, eps=1e-05, affine=True)
      (nonlinearity): ELU(alpha=1.0, inplace=True)
    )
    (1): ConvBlock(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): GroupNorm(32, 64, eps=1e-05, affine=True)
      (nonlinearity): ELU(alpha=1.0, inplace=True)
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ConvBlock(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): GroupNorm(32, 128, eps=1e-05, affine=True)
      (nonlinearity): ELU(alpha=1.0, inplace=True)
    )
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ConvBlock(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): GroupNorm(32, 256, eps=1e-05, affine