In [1]:
!pip install fastai==1.0.60

Collecting fastai==1.0.60
  Downloading fastai-1.0.60-py3-none-any.whl (237 kB)
[K     |████████████████████████████████| 237 kB 4.1 MB/s 
Installing collected packages: fastai
  Attempting uninstall: fastai
    Found existing installation: fastai 1.0.61
    Uninstalling fastai-1.0.61:
      Successfully uninstalled fastai-1.0.61
Successfully installed fastai-1.0.60


In [2]:
!pip install pytorch_model_summary

Collecting pytorch_model_summary
  Downloading pytorch_model_summary-0.1.2-py3-none-any.whl (9.3 kB)
Installing collected packages: pytorch-model-summary
Successfully installed pytorch-model-summary-0.1.2


In [3]:
!git clone https://github.com/FangShancheng/ABINet.git

Cloning into 'ABINet'...
remote: Enumerating objects: 75, done.[K
remote: Counting objects: 100% (75/75), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 75 (delta 13), reused 64 (delta 10), pack-reused 0[K
Unpacking objects: 100% (75/75), done.


In [4]:
%cd ABINet

/content/ABINet


In [5]:
import torch
import torch.nn as nn
from fastai.vision import *

from modules.model import _default_tfmer_cfg
from modules.resnet import resnet45
from modules.transformer import (PositionalEncoding,
                                 TransformerEncoder,
                                 TransformerEncoderLayer)


class ResTranformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.resnet = resnet45()

        self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
        nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead'])
        d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner'])
        dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout'])
        activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation'])
        num_layers = ifnone(config.model_vision_backbone_ln, 2)

        self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
        encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, 
                dim_feedforward=d_inner, dropout=dropout, activation=activation)
        self.transformer = TransformerEncoder(encoder_layer, num_layers)

    def forward(self, images):
        feature = self.resnet(images)
        n, c, h, w = feature.shape
        feature = feature.view(n, c, -1).permute(2, 0, 1)
        feature = self.pos_encoder(feature)
        feature = self.transformer(feature)
        feature = feature.permute(1, 2, 0).view(n, c, h, w)
        return feature

In [6]:
import math

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers):
        self.inplanes = 32
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
        self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
        self.layer5 = self._make_layer(block, 512, layers[4], stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        #print("====",self.inplanes, planes * block.expansion, planes, blocks, stride,  block.expansion)
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return x


def resnet45():
    return ResNet(BasicBlock, [3, 4, 6, 6, 3])

In [7]:
import pytorch_model_summary
import torch
net = ResNet(BasicBlock, [3, 4, 6, 6, 3])
print(net)

ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, 

In [8]:
print(pytorch_model_summary.summary(net, torch.zeros(1, 3, 32, 100 ), show_input = True))

------------------------------------------------------------------------
      Layer (type)          Input Shape         Param #     Tr. Param #
          Conv2d-1      [1, 3, 32, 100]             864             864
     BatchNorm2d-2     [1, 32, 32, 100]              64              64
            ReLU-3     [1, 32, 32, 100]               0               0
      BasicBlock-4     [1, 32, 32, 100]          11,456          11,456
      BasicBlock-5      [1, 32, 16, 50]          10,368          10,368
      BasicBlock-6      [1, 32, 16, 50]          10,368          10,368
      BasicBlock-7      [1, 32, 16, 50]          41,344          41,344
      BasicBlock-8      [1, 64, 16, 50]          41,216          41,216
      BasicBlock-9      [1, 64, 16, 50]          41,216          41,216
     BasicBlock-10      [1, 64, 16, 50]          41,216          41,216
     BasicBlock-11      [1, 64, 16, 50]         164,608         164,608
     BasicBlock-12      [1, 128, 8, 25]         164,352        

In [9]:
print(pytorch_model_summary.summary(net, torch.zeros(1, 3, 32, 100 ), show_input = True, show_hierarchical=True))

------------------------------------------------------------------------
      Layer (type)          Input Shape         Param #     Tr. Param #
          Conv2d-1      [1, 3, 32, 100]             864             864
     BatchNorm2d-2     [1, 32, 32, 100]              64              64
            ReLU-3     [1, 32, 32, 100]               0               0
      BasicBlock-4     [1, 32, 32, 100]          11,456          11,456
      BasicBlock-5      [1, 32, 16, 50]          10,368          10,368
      BasicBlock-6      [1, 32, 16, 50]          10,368          10,368
      BasicBlock-7      [1, 32, 16, 50]          41,344          41,344
      BasicBlock-8      [1, 64, 16, 50]          41,216          41,216
      BasicBlock-9      [1, 64, 16, 50]          41,216          41,216
     BasicBlock-10      [1, 64, 16, 50]          41,216          41,216
     BasicBlock-11      [1, 64, 16, 50]         164,608         164,608
     BasicBlock-12      [1, 128, 8, 25]         164,352        

In [10]:
from tensorflow.keras.applications.resnet50 import ResNet50
resnet = ResNet50(weights='imagenet')
resnet.summary()


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5
Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 230, 230, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 112, 112, 64  9472        ['conv1_pad[0][0]']              
                                )                    

In [11]:
_default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024
                          dropout=0.1, activation='relu')
_default_tfmer_cfg

{'activation': 'relu',
 'd_inner': 2048,
 'd_model': 512,
 'dropout': 0.1,
 'nhead': 8}

In [12]:
resnet = resnet45()
feature = resnet(torch.zeros(1, 3, 32, 100 ))
feature.shape

torch.Size([1, 512, 8, 25])

In [13]:
n, c, h, w = feature.shape

In [14]:
feature.view(n, c, -1).shape

torch.Size([1, 512, 200])

In [15]:
feature = feature.view(n, c, -1).permute(2, 0, 1)
feature.shape

torch.Size([200, 1, 512])

In [16]:
class BasicBlock():
    expansion = 2

block = BasicBlock()
print(block.expansion)

2


In [17]:
class BasicBlock():
    expansion = 2
    def __init__(self):
        self.expansion1 = 3

block = BasicBlock()
print(block.expansion)
print(block.expansion1)

block.expansion = 4
block.expansion1 = 5
print(block.expansion)
print(block.expansion1)

2
3
4
5


In [18]:
def BasicBlock():
    expansion = 2
    return expansion

block = BasicBlock()
#print(block.expansion)

# Transformer

In [19]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        #print("x:::", x)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [20]:
pos_encoder = PositionalEncoding(512, max_len=8*32)
pos_encoder

PositionalEncoding(
  (dropout): Dropout(p=0.1, inplace=False)
)

In [21]:
feature = pos_encoder(feature)
feature

tensor([[[ 0.0000e+00,  1.1111e+00,  0.0000e+00,  ...,  1.1111e+00,
           0.0000e+00,  1.1111e+00]],

        [[ 9.3497e-01,  6.0034e-01,  9.1317e-01,  ...,  1.1111e+00,
           1.1518e-04,  1.1111e+00]],

        [[ 1.0103e+00, -0.0000e+00,  1.0405e+00,  ...,  0.0000e+00,
           2.3036e-04,  1.1111e+00]],

        ...,

        [[ 8.8423e-01, -0.0000e+00,  1.1107e+00,  ...,  1.1109e+00,
           2.2689e-02,  0.0000e+00]],

        [[-8.8421e-02, -1.1076e+00,  0.0000e+00,  ...,  1.1109e+00,
           2.2804e-02,  0.0000e+00]],

        [[-9.7978e-01, -0.0000e+00, -3.6057e-01,  ...,  1.1109e+00,
           2.2919e-02,  1.1109e+00]]], grad_fn=<MulBackward0>)

In [22]:
max_len = 8*32
d_model = 512
pe = torch.zeros(max_len, d_model)
pe.shape

torch.Size([256, 512])

In [23]:
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
position.shape

torch.Size([256, 1])

# TransformerEncoderLayer

In [24]:
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

In [25]:
class TransformerEncoderLayer(Module):


    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 
                 activation="relu", debug=False):
        super(TransformerEncoderLayer, self).__init__()
        self.debug = debug
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)
        if self.debug: self.attn = attn
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        
        return src

In [26]:
_qkv_same_embed_dim = 512 == 512 and 512 == 512
_qkv_same_embed_dim

True

In [27]:
from torch.nn import Parameter
import torch
embed_dim= 512
in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
in_proj_weight.shape

torch.Size([1536, 512])

In [28]:
if not True:
    print("1")
else:
    print("2")

2


In [29]:
in_proj_bias = Parameter(torch.empty(3 * embed_dim))
in_proj_bias.shape

torch.Size([1536])

# MultiheadAttention

In [30]:
import copy
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_uniform_
class MultiheadAttention(Module):
  

    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

In [31]:
encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, 
                dim_feedforward=2048, dropout=0.1, activation='relu')
encoder_layer

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [32]:
def multi_head_attention_forward(query,                           # type: Tensor
                                 key,                             # type: Tensor
                                 value,                           # type: Tensor
                                 embed_dim_to_check,              # type: int
                                 num_heads,                       # type: int
                                 in_proj_weight,                  # type: Tensor
                                 in_proj_bias,                    # type: Tensor
                                 bias_k,                          # type: Optional[Tensor]
                                 bias_v,                          # type: Optional[Tensor]
                                 add_zero_attn,                   # type: bool
                                 dropout_p,                       # type: float
                                 out_proj_weight,                 # type: Tensor
                                 out_proj_bias,                   # type: Tensor
                                 training=True,                   # type: bool
                                 key_padding_mask=None,           # type: Optional[Tensor]
                                 need_weights=True,               # type: bool
                                 attn_mask=None,                  # type: Optional[Tensor]
                                 use_separate_proj_weight=False,  # type: bool
                                 q_proj_weight=None,              # type: Optional[Tensor]
                                 k_proj_weight=None,              # type: Optional[Tensor]
                                 v_proj_weight=None,              # type: Optional[Tensor]
                                 static_k=None,                   # type: Optional[Tensor]
                                 static_v=None                    # type: Optional[Tensor]
                                 ):

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # # convert ByteTensor key_padding_mask to bool
    # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
    #     warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
    #     key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

 
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)


    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

  
    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]


    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask


    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = F.softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

In [33]:
embed_dim = 512
num_heads = 8
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
multihead_attn

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)

In [34]:
 #F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
print(torch.arange(7))
print(torch.arange(7).chunk(3, dim=-1) )
print(torch.arange(7).chunk(3, dim=-1) )

tensor([0, 1, 2, 3, 4, 5, 6])
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))


In [35]:
head_dim = 8
scaling = float(head_dim) ** -0.5
scaling

0.3535533905932738

In [36]:
attn_output_weights = torch.Tensor([[1,2], [3,4]])
attn_output_weights.sum(dim=1)

tensor([3., 7.])

# TransformerEncoder

In [37]:
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = src

        for i, mod in enumerate(self.layers):
            print("mod:",mod , src.shape, output.shape, mask ,  src_key_padding_mask)
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [38]:
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])

In [39]:
num_layers = 3
transformer = TransformerEncoder(encoder_layer, num_layers)
transformer

TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,

In [40]:
feature = transformer(feature)
feature

mod: TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
) torch.Size([200, 1, 512]) torch.Size([200, 1, 512]) None None
mod: TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((5

tensor([[[ 0.5752,  0.4777, -1.2485,  ...,  1.4943, -0.0234,  0.5956]],

        [[ 1.2946, -0.2978, -1.0299,  ...,  0.5991,  0.0688,  0.1283]],

        [[ 1.4604, -0.0059,  0.4085,  ..., -1.3981, -0.1582,  0.4778]],

        ...,

        [[ 1.1164, -0.4783,  0.2803,  ...,  0.6602,  0.5824,  0.2194]],

        [[ 0.0337, -1.1962, -0.1274,  ...,  0.3636,  0.3371,  0.2454]],

        [[-0.9521,  0.1579, -1.1347,  ...,  0.8674,  0.3711,  0.6302]]],
       grad_fn=<NativeLayerNormBackward0>)

In [41]:
print(feature.shape)
# n = 1
# c = 512
# h = 8
# w = 25
feature.permute(1, 2, 0).view(n, c, h, w).shape

torch.Size([200, 1, 512])


torch.Size([1, 512, 8, 25])

In [42]:
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

model = MyModule()
model

MyModule(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
  )
)

In [43]:
bias = True
out_proj = Linear(512, 512, bias=bias)
out_proj.weight.shape , out_proj.bias.shape

(torch.Size([512, 512]), torch.Size([512]))

In [44]:
m = nn.Linear(1, 2)
m.weight.shape, m.bias.shape

(torch.Size([2, 1]), torch.Size([2]))

# PositionAttention

In [45]:
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))

def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
    align_corners = None if mode=='nearest' else True
    return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 
                                     mode=mode, align_corners=align_corners),
                         nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))

In [54]:

class PositionAttention(nn.Module):
    def __init__(self, max_length, in_channels=512, num_channels=64, 
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2))
        )
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
        )

        self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)

        # calculate key vector
        features = []
        for i in range(0, len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(0, len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # calculate query vector
        # TODO q=f(q,k)
        zeros = x.new_zeros((self.max_length, N, E))  # (T, N, E)
        q = self.pos_encoder(zeros)  # (T, N, E)
        q = q.permute(1, 0, 2)  # (N, T, E)
        q = self.project(q)  # (N, T, E)
        
        # calculate attention
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)

        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        return attn_vecs, attn_scores.view(N, -1, H, W)

In [47]:
PositionAttention(26 )

PositionAttention(
  (k_encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (k_decoder): Sequential(
    (0): Sequential(
 

In [80]:
k_encoder = nn.Sequential(
            encoder_layer(512, 64, s=(1, 2)),
            encoder_layer(64, 64, s=(2, 2)),
            encoder_layer(64, 64, s=(2, 2)),
            encoder_layer(64, 64, s=(2, 2))
        )

k_decoder = nn.Sequential(
            encoder_layer(512, 64, s=(1, 2)),
            encoder_layer(64, 64, s=(2, 2)),
            encoder_layer(64, 64, s=(2, 2)),
            encoder_layer(64, 64, s=(2, 2))
        )

for i in range(0, len(k_decoder) - 1):               
    #k = k_decoder[i](k)            
    print(len(k_decoder) - 2 - i)
    #k = k + features[len(self.k_decoder) - 2 - i]

2
1
0


In [90]:
t = torch.tensor([[[[1, 2, 4],
                   [3, 4, 3]]],
                 [[[5, 6, 1],
                   [7, 8, 1]]]])
print(t.shape)
t.flatten(2, 3)
print(t.flatten(2, 3).shape)

torch.Size([2, 1, 2, 3])
torch.Size([2, 1, 6])


In [96]:
t = torch.tensor([[[1, 2, 4],
                   [3, 4, 3]],
                 [[5, 6, 1],
                   [7, 8, 1]]])
print(t.shape)
#t.view(N, -1, H, W)
print(t.view(2, -1, 2, 3).shape)

torch.Size([2, 2, 3])
torch.Size([2, 1, 2, 3])


In [123]:
def _get_length(logit, dim=-1):
        """ Greed decoder to obtain length from logit"""
        out = (logit.argmax(dim=-1) == 0 )
        abn = out.any(dim)
        print(out)
        print(out.cumsum(dim))
        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
        out = out + 1  # additional end token
        out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
        return out
logit = torch.Tensor([[0.8808, 0.1192, 0.9933],
        [0.1192, 0.8808, 0.0067]])
_get_length(logit)

tensor([False, False])
tensor([0, 0])
tensor(1)


tensor(3)

In [117]:
x = torch.tensor([[[5,2,7]],[[3,4,2]]],dtype=torch.float)
print(F.softmax(x,dim=0))
print(F.softmax(x,dim=-1))
#print(F.softmax(x,dim=2).shape)

tensor([[[0.8808, 0.1192, 0.9933]],

        [[0.1192, 0.8808, 0.0067]]])
tensor([[[0.1185, 0.0059, 0.8756]],

        [[0.2447, 0.6652, 0.0900]]])


In [49]:
import logging
import os
import time

import cv2
import numpy as np
import torch
import yaml
from matplotlib import colors
from matplotlib import pyplot as plt
from torch import Tensor, nn

class Config1(object):

    def __init__(self, config_path, host=True):
        def __dict2attr(d, prefix=''):
            for k, v in d.items():
                if isinstance(v, dict):
                    __dict2attr(v, f'{prefix}{k}_')
                else:
                    if k == 'phase':
                        assert v in ['train', 'test']
                    if k == 'stage':
                        assert v in ['pretrain-vision', 'pretrain-language',
                                     'train-semi-super', 'train-super']
                    self.__setattr__(f'{prefix}{k}', v)

        assert os.path.exists(config_path), '%s does not exists!' % config_path
        with open(config_path) as file:
            #config_dict = yaml.load(file, Loader=yaml.FullLoader)
            config_dict = yaml.load(file)
        with open('configs/template.yaml') as file:
            #default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
            default_config_dict = yaml.load(file)
        __dict2attr(default_config_dict)
        __dict2attr(config_dict)
        self.global_workdir = os.path.join(self.global_workdir, self.global_name)

    def __getattr__(self, item):
        attr = self.__dict__.get(item)
        if attr is None:
            attr = dict()
            prefix = f'{item}_'
            for k, v in self.__dict__.items():
                if k.startswith(prefix):
                    n = k.replace(prefix, '')
                    attr[n] = v
            return attr if len(attr) > 0 else None
        else:
            return attr

    def __repr__(self):
        str = 'ModelConfig(\n'
        for i, (k, v) in enumerate(sorted(vars(self).items())):
            str += f'\t({i}): {k} = {v}\n'
        str += ')'
        return str
configs = "configs/pretrain_vision_model.yaml"
config = Config1(configs)
print(config)

ModelConfig(
	(0): dataset_case_sensitive = False
	(1): dataset_charset_path = data/charset_36.txt
	(2): dataset_data_aug = True
	(3): dataset_eval_case_sensitive = False
	(4): dataset_image_height = 32
	(5): dataset_image_width = 128
	(6): dataset_max_length = 25
	(7): dataset_multiscales = False
	(8): dataset_num_workers = 14
	(9): dataset_one_hot_y = True
	(10): dataset_pin_memory = True
	(11): dataset_smooth_factor = 0.1
	(12): dataset_smooth_label = False
	(13): dataset_test_batch_size = 384
	(14): dataset_test_roots = ['data/evaluation/IIIT5k_3000', 'data/evaluation/SVT', 'data/evaluation/SVTP', 'data/evaluation/IC13_857', 'data/evaluation/IC15_1811', 'data/evaluation/CUTE80']
	(15): dataset_train_batch_size = 384
	(16): dataset_train_roots = ['data/training/MJ/MJ_train/', 'data/training/MJ/MJ_test/', 'data/training/MJ/MJ_valid/', 'data/training/ST']
	(17): dataset_use_sm = False
	(18): global_name = pretrain-vision-model
	(19): global_phase = train
	(20): global_seed = None
	(21

In [50]:
configs = "configs/pretrain_vision_model.yaml"
config = Config1(configs)
config

ModelConfig(
	(0): dataset_case_sensitive = False
	(1): dataset_charset_path = data/charset_36.txt
	(2): dataset_data_aug = True
	(3): dataset_eval_case_sensitive = False
	(4): dataset_image_height = 32
	(5): dataset_image_width = 128
	(6): dataset_max_length = 25
	(7): dataset_multiscales = False
	(8): dataset_num_workers = 14
	(9): dataset_one_hot_y = True
	(10): dataset_pin_memory = True
	(11): dataset_smooth_factor = 0.1
	(12): dataset_smooth_label = False
	(13): dataset_test_batch_size = 384
	(14): dataset_test_roots = ['data/evaluation/IIIT5k_3000', 'data/evaluation/SVT', 'data/evaluation/SVTP', 'data/evaluation/IC13_857', 'data/evaluation/IC15_1811', 'data/evaluation/CUTE80']
	(15): dataset_train_batch_size = 384
	(16): dataset_train_roots = ['data/training/MJ/MJ_train/', 'data/training/MJ/MJ_test/', 'data/training/MJ/MJ_valid/', 'data/training/ST']
	(17): dataset_use_sm = False
	(18): global_name = pretrain-vision-model
	(19): global_phase = train
	(20): global_seed = None
	(21

In [68]:
import logging
import torch.nn as nn
from fastai.vision import *

from modules.attention import *
from modules.backbone import ResTranformer
from modules.model import Model
from modules.resnet import resnet45

class BaseVision(Model):
    def __init__(self, config):
        super().__init__(config)
        self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
        self.out_channels = ifnone(config.model_vision_d_model, 512)

        if config.model_vision_backbone == 'transformer':
            self.backbone = ResTranformer(config)
        else: self.backbone = resnet45()
        
        if config.model_vision_attention == 'position':
            mode = ifnone(config.model_vision_attention_mode, 'nearest')
            self.attention = PositionAttention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                mode=mode,
            )
        elif config.model_vision_attention == 'attention':
            self.attention = Attention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                n_feature=8*32,
            )
        else:
            raise Exception(f'{config.model_vision_attention} is not valid.')
        self.cls = nn.Linear(self.out_channels, self.charset.num_classes)

        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

    def forward(self, images, *args):
        print("111111111111============================")
        features = self.backbone(images)  # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs) # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'}

In [71]:
def _get_model(config):
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model

model = _get_model(config)
model

BaseVision(
  (backbone): ResTranformer(
    (resnet): ResNet(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock

In [63]:
BaseVision(config)

BaseVision(
  (backbone): ResTranformer(
    (resnet): ResNet(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock

TypeError: ignored