<a href="https://colab.research.google.com/github/jjjonathan14/crowdhuman_tensorRT/blob/main/Copy_of_mnsit_2(2)(2)(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [83]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [84]:

# !pip install einops
import random

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import numpy as np
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch
import math
import numpy as np
import matplotlib.pyplot as plt


def batch_plot(img, n, h, w):
    """Show up to ``n * n`` images from ``img`` on an ``n`` x ``n`` grid.

    Args:
        img: Indexable, sized collection of images; each ``img[i]`` is passed
            to ``plt.imshow`` unchanged.
        n: Number of grid rows and columns; also used as the figure width and
            height passed to ``plt.figure``.
        h: Unused; kept so existing call sites keep working.
        w: Unused; kept so existing call sites keep working.
    """
    fig = plt.figure(figsize=(n, n))
    # Never index past the end of ``img`` when it holds fewer than n*n images.
    count = min(n * n, len(img))
    for idx in range(count):
        # Subplot positions are 1-based while image indices are 0-based, so
        # image 0 goes into cell 1 and the last cell of the grid is filled too.
        fig.add_subplot(n, n, idx + 1)
        plt.imshow(img[idx])
    plt.show()


In [85]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence for a transformer.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every patch to an ``emb_size`` vector, the spatial grid is flattened into
    a sequence, a learned class token is prepended and a learned position
    embedding is added.

    Args:
        in_channels: Channels of the input images.
        patch_size: Edge length of one square patch.
        emb_size: Width of every produced token.
        img_size: Edge length of the input images; fixes the number of
            position embeddings at ``(img_size // patch_size) ** 2 + 1``.
    """

    def __init__(self, in_channels: int = 1, patch_size: int = 4, emb_size: int = 784, img_size: int = 28):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # A strided conv is used in place of an explicit per-patch linear layer.
        patch_conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.projection = nn.Sequential(
            patch_conv,
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position vector per patch plus one for the class token.
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Return the tokens for ``x``: ``(batch, patches + 1, emb_size)``."""
        batch, _channels, _height, _width = x.shape

        patches = self.projection(x)
        # Copy the single class token once per sample and put it in front.
        cls = repeat(self.cls_token, '() n e -> b n e', b=batch)
        tokens = torch.cat([cls, patches], dim=1)

        # Position embedding broadcasts over the batch axis.
        tokens += self.positions
        return tokens

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Queries, keys and values come from one fused linear layer. Attention
    logits are divided by ``sqrt(head_dim)`` *before* the softmax so every
    row of attention weights sums to one.

    Args:
        emb_size: Width of each input and output token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, emb_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``; positions where it is
                ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, tokens, _ = x.shape
        head_dim = self.emb_size // self.num_heads

        # The fused axis is laid out as (head, head_dim, qkv) with qkv varying
        # fastest, i.e. the einops pattern "b n (h d qkv) -> (qkv) b h n d".
        fused = self.qkv(x).reshape(batch, tokens, self.num_heads, head_dim, 3)
        queries, keys, values = fused.permute(4, 0, 2, 1, 3).unbind(0)

        # batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd,bhkd->bhqk', queries, keys) / head_dim ** 0.5
        if mask is not None:
            # masked_fill is out-of-place, so the result has to be kept.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)

        # Weighted sum of the values, then merge the heads back: (b, n, h*d).
        out = torch.einsum('bhal,bhlv->bhav', att, values)
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, self.emb_size)
        return self.projection(out)

class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = fn(x) + x``.

    Args:
        fn: Callable (normally an ``nn.Module``) whose output has a shape that
            can be added to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; extra keyword args go to ``fn``."""
        # Add out-of-place: an in-place ``+=`` would write into fn's output,
        # which breaks autograd when fn's last op saved that output for its
        # backward pass, and would modify the caller's tensor whenever fn
        # returns its input unchanged.
        return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Width of the input and output tokens.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int = 768, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: attention then MLP, each with a skip path.

    Both sub-layers have the form ``LayerNorm -> sub-layer -> Dropout`` and
    are wrapped in ``ResidualAdd``.

    Args:
        emb_size: Token width.
        drop_p: Dropout probability applied at the end of each sub-layer.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
        **kwargs: Forwarded to ``MultiHeadAttention`` (e.g. ``num_heads``).
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.2,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.1,
                 ** kwargs):
        # Built in the same order as they run: attention first, then the MLP.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )

class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` independently initialised encoder blocks.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers to chain.
        **kwargs: Forwarded unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth: int = 12, **kwargs):
        blocks = tuple(TransformerEncoderBlock(**kwargs) for _ in range(depth))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Average all tokens, normalise, and project to class logits.

    Args:
        emb_size: Token width.
        n_classes: Number of output logits.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        # Mean over the token axis: (batch, tokens, emb) -> (batch, emb).
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalise = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(pool_tokens, normalise, classifier)

class ViT(nn.Sequential):
    """Vision transformer: patch embedding -> encoder stack -> classifier.

    Args:
        in_channels: Channels of the input images.
        patch_size: Edge length of one square patch.
        emb_size: Token width used throughout the model.
        img_size: Edge length of the input images.
        depth: Number of encoder blocks.
        n_classes: Number of output logits.
        **kwargs: Forwarded to ``TransformerEncoder``.
    """

    def __init__(self,
                 in_channels: int = 64,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 56,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

In [86]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2
from tqdm import tqdm
import time
import os
from torch.utils.tensorboard import SummaryWriter



batch_size = 32

# Mean and standard deviation handed to Normalize for both splits.
_MNIST_STATS = ((0.1307,), (0.3081,))

# Training pipeline: tensor conversion, random rotation and random affine
# (both with degrees=10), then normalisation.
# (CenterCrop((24, 24)) and Resize((28, 28)) steps are currently disabled.)
_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(10),
    transforms.Normalize(*_MNIST_STATS),
])

# Test pipeline: no augmentation, only tensor conversion and normalisation.
_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*_MNIST_STATS),
])

# Both splits are stored under (and downloaded to) '/files/'.
_train_set = torchvision.datasets.MNIST(
    '/files/', train=True, download=True, transform=_train_transform)
train_loader = torch.utils.data.DataLoader(
    _train_set, batch_size=batch_size, shuffle=True)

_test_set = torchvision.datasets.MNIST(
    '/files/', train=False, download=True, transform=_test_transform)
test_loader = torch.utils.data.DataLoader(
    _test_set, batch_size=batch_size, shuffle=False)

# Class names: the ten digits as strings.
classes = tuple(str(digit) for digit in range(10))


def imshow(img, mean=0.1307, std=0.3081):
    """Display a normalised ``(C, H, W)`` image tensor with matplotlib.

    Args:
        img: Image tensor in channel-first layout, normalised as
            ``(pixel - mean) / std``.
        mean: Mean used for normalisation. The default matches the
            ``Normalize((0.1307,), (0.3081,))`` applied by this file's loaders.
        std: Standard deviation used for normalisation (same source).
    """
    # Invert the normalisation; detach/cpu so tensors that live on the GPU or
    # carry gradients can be converted to numpy.
    img = img.detach().cpu() * std + mean
    npimg = np.transpose(img.numpy(), (1, 2, 0))
    if npimg.shape[-1] == 1:
        # matplotlib accepts (H, W), (H, W, 3) and (H, W, 4) arrays but
        # rejects (H, W, 1), so drop the singleton channel axis.
        npimg = npimg[..., 0]
    plt.imshow(npimg)
    plt.show()



In [87]:

# class Hybrid(nn.Module):
#     """
#     block: A sub module
#     """

#     def __init__(self, inchannels=1, n_classes=10, patch_size=4, img_size=28, heads=4, dropout=0.1, expansion=4,
#                  f_drop=0.1, batch_size=1):
#         super(Hybrid, self).__init__()
#         self.batch_size = batch_size
#         self.patch_size = (img_size//patch_size)**2 + 1
#         self.img_size = img_size
#         self.embedding = img_size ** 2
#         self.relu = nn.ReLU(inplace=True)
#         self.conv1 = nn.Conv2d(inchannels, 2 * inchannels, kernel_size=5, padding=2, stride=1, dilation=1)
#         self.bn1 = nn.BatchNorm2d(2 * inchannels)
#         self.conv2 = nn.Conv2d(2 * inchannels, 2 * inchannels, kernel_size=3, padding=1, stride=1, dilation=1)
#         self.conv3 = nn.Conv2d(2 * inchannels, 4 * inchannels, kernel_size=1, padding=0, stride=1, dilation=1)
#         self.bn2 = nn.BatchNorm2d(4 * inchannels)
#         self.conv4 = nn.Conv2d(4 * inchannels, 4, kernel_size=1, padding=0, stride=1, dilation=1)
#         self.bn4 = nn.BatchNorm2d(4)
#         # residual
#         self.conv11 = nn.Conv2d(inchannels, 2, kernel_size=3, padding=1, dilation=1, stride=1)
#         self.bn11 = nn.BatchNorm2d(2)
#         # transformer
#         self.patchEmbedding = PatchEmbedding(in_channels=inchannels, patch_size=patch_size, emb_size=self.embedding,
#                                              img_size=img_size)
#         self.transformerEncoder = TransformerEncoderBlock(emb_size=self.embedding,
#                                                           drop_p=dropout,
#                                                           forward_expansion=expansion,
#                                                           forward_drop_p=f_drop)
        
#         self.m = nn.Sequential(
#                             #nn.Linear(50*28*28, 50*28*28),
#                             nn.Unflatten(2, (28, 28))
#                           )
#         self.unflatten = nn.Unflatten(1, (inchannels, img_size, img_size))
#         self.unflatten1 = nn.Unflatten(1, (img_size, img_size))

#         self.conv22 = nn.Conv2d(50, 10, kernel_size=3, padding=1, dilation=1, stride=1)


#         self.conv33 = nn.Conv2d(50, 8, kernel_size=3, padding=1, dilation=1, stride=1)

#         self.conv44 = nn.Conv2d(14, 8, kernel_size=3, padding=1, dilation=1, stride=1)

#         self.maxPool = torch.nn.MaxPool2d(2, 2)
#         self.dropout = torch.nn.Dropout(0.1, inplace=True)

#         self.full = torch.Tensor().cuda(0)
#         self.batch = torch.Tensor().cuda(0)

#         self.nn1 = torch.nn.Linear(8*196, 4*196)
#         self.nn11 = torch.nn.Linear(4*196, 10)
#         self.nn2 = torch.nn.Linear(2*196, 10)
#         self.nn3 = torch.nn.Linear(2*196, 10)


#         self.patchEmbedding1 = PatchEmbedding(in_channels=2*inchannels, patch_size=patch_size, emb_size=self.embedding,
#                                              img_size=img_size)
#         self.patchEmbedding2 = PatchEmbedding(in_channels=2*inchannels, patch_size=patch_size, emb_size=self.embedding,
#                                              img_size=img_size)
#         self.patchEmbedding3 = PatchEmbedding(in_channels=4*inchannels, patch_size=patch_size, emb_size=self.embedding,
#                                              img_size=img_size)
#         self.patchEmbedding4 = PatchEmbedding(in_channels=4*inchannels, patch_size=patch_size, emb_size=self.embedding,
#                                              img_size=img_size)

#     def forward(self, x):
#         y = x.clone
#         z = x.clone
#         x1 = self.bn1(self.conv1(x))
#         a1 = self.patchEmbedding1(x1)
#         a1 = self.transformerEncoder(a1)
#         a1 = self.m(a1)
#         a1 = self.conv22(a1)
#         x1 = self.bn1(self.conv2(x1))
#         a2 = self.patchEmbedding2(x1)
#         a2 = self.transformerEncoder(a2)
#         a2 = self.m(a2)
#         a2 = self.conv22(a2)
#         x1 = self.bn2(self.conv3(x1))
#         a3 = self.patchEmbedding3(x1)
#         a3 = self.transformerEncoder(a3)
#         a3 = self.m(a3)
#         a3 = self.conv22(a3)
#         x1 = self.bn4(self.conv4(x1))
#         a4 = self.patchEmbedding4(x1)
#         a4 = self.transformerEncoder(a4)
#         a4 = self.m(a4)
#         a4 = self.conv22(a4)
 
        
       
#         # residual
#         y = self.bn11(self.conv11(x))
#         # transformer

#         z = self.patchEmbedding(x)
#         z = self.transformerEncoder(z)
#         z = self.m(z)
#         z = self.conv22(z)

#         A = torch.cat((z, a1, a2, a3, a4), 1)
#         z = self.conv33(A)
        
#         x = torch.cat((x1, z, y), 1)
#         x = self.conv44(x)
#         x = self.maxPool(x)
#         # print('z',z.shape, 'y', y.shape, 'x1', x1.shape)
#         # x_ = torch.cat((x, x1),1)
#         # print('x_', x_.shape)
#         # x = x1 + x
#         # x = self.maxPool(x_)
#         # # x = self.maxPool(x)
#         # y = self.maxPool(y)
#         # z = self.maxPool(z) 

#         # # x_ = torch.mean(x_, 1)
#         # # y_ = torch.mean(y, 1)
#         # # z_ = torch.mean(z, 1)

        

#         x = torch.flatten(x, start_dim=1)
#         # y = torch.flatten(y, start_dim=1)
#         # z = torch.flatten(z, start_dim=1)
#         # #print('x', x.shape, 'y', y.shape, 'z', z.shape)
#         x = self.nn11(self.nn1(x))
#         # # y = self.nn2(y)
#         # # z = self.nn3(z)

#         return x,0,0

In [88]:
from torch.nn.modules.adaptive import AdaptiveLogSoftmaxWithLoss

class Hybrid(nn.Module):
    """CNN / transformer hybrid classifier built from two fusion stages.

    Each stage feeds the same input to three parallel branches: a stack of
    convolutions, a single 3x3 "residual" convolution, and a patch-embedding
    plus transformer-encoder branch whose token matrix is unflattened into
    feature maps. The branch outputs are concatenated on the channel axis,
    fused by 3x3 convolutions and down-sampled by 2x2 max pooling. Two linear
    layers turn the flattened result into class logits.

    NOTE: several sizes are hard-coded for the default configuration (28x28
    input, ``patch_size=4``): the ``Unflatten`` targets (28, 28) and (14, 14),
    the 50 input channels of ``conv22`` / ``conv22_2`` (7 * 7 patches + 1 class
    token) and ``Linear(8 * 49, ...)``. Other ``img_size`` / ``patch_size``
    values are not expected to work without changing those layers.

    Args:
        inchannels: Channels of the input images.
        n_classes: Number of output logits.
        patch_size: Patch edge length of the stage-1 patch embedding.
        img_size: Input edge length; the stage-1 token width is
            ``img_size ** 2``.
        heads: Accepted for backward compatibility but currently unused: the
            stage-1 encoder uses ``MultiHeadAttention``'s default head count
            and the stage-2 encoder is fixed to 4 heads.
        dropout: ``drop_p`` of both encoder blocks.
        expansion: ``forward_expansion`` of both encoder blocks.
        f_drop: ``forward_drop_p`` of both encoder blocks.
        batch_size: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, inchannels=1, n_classes=10, patch_size=4, img_size=28, heads=4, dropout=0.1, expansion=4,
                 f_drop=0.1, batch_size=1):
        super(Hybrid, self).__init__()
        self.batch_size = batch_size
        # Despite the name this is the token count: patches plus class token.
        self.patch_size = (img_size//patch_size)**2 + 1
        self.img_size = img_size
        self.embedding = img_size ** 2
        self.relu = nn.ReLU(inplace=True)

        # ---- stage 1: convolution branch ----
        self.conv1 = nn.Conv2d(inchannels, 2 * inchannels, kernel_size=5, padding=2, stride=1, dilation=1)
        self.bn1 = nn.BatchNorm2d(2 * inchannels)
        self.conv2 = nn.Conv2d(2 * inchannels, 2 * inchannels, kernel_size=3, padding=1, stride=1, dilation=1)
        self.conv3 = nn.Conv2d(2 * inchannels, 4 * inchannels, kernel_size=1, padding=0, stride=1, dilation=1)
        self.bn2 = nn.BatchNorm2d(4 * inchannels)
        self.conv4 = nn.Conv2d(4 * inchannels, 4, kernel_size=1, padding=0, stride=1, dilation=1)
        self.bn4 = nn.BatchNorm2d(4)
        # residual
        self.conv11 = nn.Conv2d(inchannels, 2, kernel_size=3, padding=1, dilation=1, stride=1)
        self.bn11 = nn.BatchNorm2d(2)
        # transformer
        self.patchEmbedding = PatchEmbedding(in_channels=inchannels, patch_size=patch_size, emb_size=self.embedding,
                                             img_size=img_size)
        self.transformerEncoder = TransformerEncoderBlock(emb_size=self.embedding,
                                                          drop_p=dropout,
                                                          forward_expansion=expansion,
                                                          forward_drop_p=f_drop)

        # Reshape every token's embedding vector into a 28x28 map.
        self.m = nn.Sequential(
                            nn.Unflatten(2, (28, 28))
                          )
        self.unflatten = nn.Unflatten(1, (inchannels, img_size, img_size))
        self.unflatten1 = nn.Unflatten(1, (img_size, img_size))

        # Treats the 50 tokens as input channels.
        self.conv22 = nn.Conv2d(50, 10, kernel_size=3, padding=1, dilation=1, stride=1)

        # Not used by forward(); kept so existing checkpoints still load.
        self.conv33 = nn.Conv2d(50, 8, kernel_size=3, padding=1, dilation=1, stride=1)

        # Fuses the concatenated branches: 4 (conv) + 10 (transformer) + 2 (residual).
        self.conv44 = nn.Conv2d(16, 8, kernel_size=3, padding=1, dilation=1, stride=1)

        self.maxPool = torch.nn.MaxPool2d(2, 2)
        self.dropout = torch.nn.Dropout(0.1, inplace=True)

        # Empty scratch tensors, kept for backward compatibility. They are
        # registered as non-persistent buffers rather than created with
        # ``.cuda(0)`` so the model can be constructed without a GPU; they
        # follow ``.to()`` / ``.cuda()`` and stay out of the state dict.
        self.register_buffer('full', torch.Tensor(), persistent=False)
        self.register_buffer('batch', torch.Tensor(), persistent=False)

        # Classifier on the flattened stage-2 output (8 channels of 7x7).
        self.nn1 = torch.nn.Linear(8*49, 4*49)
        self.nn11 = torch.nn.Linear(4*49, n_classes)

        # ---- stage 2: convolution branch ----
        self.conv1_2 = nn.Conv2d(8, 2 * 8, kernel_size=5, padding=2, stride=1, dilation=1)
        self.bn1_2 = nn.BatchNorm2d(2 * 8)
        self.conv2_2 = nn.Conv2d(2 * 8, 2 * 8, kernel_size=3, padding=1, stride=1, dilation=1)
        self.conv3_2 = nn.Conv2d(2 * 8, 4 * 8, kernel_size=1, padding=0, stride=1, dilation=1)
        self.bn2_2 = nn.BatchNorm2d(4 * 8)
        self.conv4_2 = nn.Conv2d(4 * 8, 16, kernel_size=1, padding=0, stride=1, dilation=1)
        self.bn4_2 = nn.BatchNorm2d(16)

        # residual
        self.conv11_2 = nn.Conv2d(8, 8, kernel_size=3, padding=1, dilation=1, stride=1)
        self.bn11_2 = nn.BatchNorm2d(8)

        # transformer
        self.patchEmbedding_2 = PatchEmbedding(in_channels=8, patch_size=2, emb_size=196,
                                             img_size=14)
        self.transformerEncoder_2 = TransformerEncoderBlock(emb_size=196,
                                                          drop_p=dropout,
                                                          forward_expansion=expansion,
                                                          forward_drop_p=f_drop, num_heads=4)

        # Reshape every token's embedding vector into a 14x14 map.
        self.m_2 = nn.Sequential(
                            nn.Unflatten(2, (14, 14))
                          )
        self.conv22_2 = nn.Conv2d(50, 8, kernel_size=3, padding=1, dilation=1, stride=1)
        # Fuses 16 (conv) + 8 (transformer) + 8 (residual) channels.
        self.conv44_2 = nn.Conv2d(32, 16, kernel_size=3, padding=1, dilation=1, stride=1)
        self.conv44_3 = nn.Conv2d(16, 8, kernel_size=3, padding=1, dilation=1, stride=1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch; with the default configuration this is expected to
                be ``(batch, inchannels, 28, 28)``.

        Returns:
            Tuple ``(logits, 0, 0)``. The two trailing zeros are placeholders
            that keep the three-value return signature callers unpack.
        """
        # ---- stage 1 ----
        # Convolution branch. NOTE: bn1 is applied after both conv1 and conv2,
        # so the two positions share one set of BatchNorm statistics.
        x1 = self.bn1(self.conv1(x))
        x1 = self.bn1(self.conv2(x1))
        x1 = self.bn2(self.conv3(x1))
        x1 = self.bn4(self.conv4(x1))

        # residual
        y = self.bn11(self.conv11(x))

        # transformer: tokens -> encoder -> one map per token -> conv
        z = self.patchEmbedding(x)
        z = self.transformerEncoder(z)
        z = self.m(z)
        z = self.conv22(z)

        # Fuse the three branches along the channel axis and down-sample.
        A = torch.cat((x1, z, y), 1)
        A = self.conv44(A)
        A = self.maxPool(A)

        # ---- stage 2 ----
        # residual
        y_1 = self.bn11_2(self.conv11_2(A))

        # transformer
        z_1 = self.patchEmbedding_2(A)
        z_1 = self.transformerEncoder_2(z_1)
        z_1 = self.m_2(z_1)
        z_1 = self.conv22_2(z_1)

        # Convolution branch (bn1_2 is likewise shared by conv1_2 and conv2_2).
        A = self.bn1_2(self.conv1_2(A))
        A = self.bn1_2(self.conv2_2(A))
        A = self.bn2_2(self.conv3_2(A))
        A = self.bn4_2(self.conv4_2(A))

        AA = torch.cat((A, z_1, y_1), 1)
        AA = self.conv44_2(AA)
        AA = self.conv44_3(AA)
        AA = self.maxPool(AA)

        # ---- classifier ----
        AA = torch.flatten(AA, start_dim=1)
        x = self.nn11(self.nn1(AA))

        return x,0,0

In [89]:
# Build the hybrid network with its default hyper-parameters.
model = Hybrid()
# Optional warm start from a previously saved checkpoint (disabled):
# model.load_state_dict(torch.load('/content/hybrid95.pth'))
# model.eval()

In [90]:

# Move the network's parameters and buffers to the default CUDA device.
model.cuda()

# Classification loss, summed (not averaged) over every mini-batch.
criterion = nn.CrossEntropyLoss(reduction='sum')
# Huber loss, created for the auxiliary (SSL) objective.
criterion1 = nn.HuberLoss()

# Adam is the active optimiser; an SGD variant (lr=0.01, momentum=0.9) is disabled.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

In [None]:
def _evaluate_accuracy(net, loader):
    """Return the top-1 accuracy of ``net`` over ``loader`` in percent."""
    was_training = net.training
    # Evaluation mode: dropout is disabled and BatchNorm uses (and does not
    # update) its running statistics while the test set is scored.
    net.eval()
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()
            outputs, _, _ = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)
    return 100 * (float(correct) / total) if total else 0.0


torch.cuda.empty_cache()
start_acc = 0.0  # best test accuracy seen so far
for epoch in range(0, 1000):  # loop over the dataset multiple times
    running_loss_hybrid = 0.0
    # The auxiliary SSL objective is disabled, so this stays at zero; it is
    # still reported to keep the log format unchanged.
    running_loss_ssl = 0.0

    # One optimisation step per training batch. The batch variables are only
    # ever bound to training data: evaluation runs after this loop and uses
    # its own names, so no test batch can leak into an optimisation step.
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        X, Y, Z = model(inputs)
        loss1 = criterion(X, labels)
        loss1.backward(retain_graph=False)
        optimizer.step()
        running_loss_hybrid += loss1.item()

    # NOTE(review): lr_scheduler is never stepped in this loop, so the
    # learning rate stays constant; call lr_scheduler.step() here if the
    # exponential decay is intended.

    # Score the test set once per epoch, after all training batches.
    accuracy = _evaluate_accuracy(model, test_loader)
    if start_acc < accuracy:
        # New best accuracy: keep a checkpoint named after the epoch.
        start_acc = accuracy
        torch.save(model.state_dict(), f'hybrid{epoch}.pth')

    print(f'epochs {epoch} Hybrid {float(running_loss_hybrid)/(len(train_loader))} SSL loss {float(running_loss_ssl)/(len(train_loader))} Accuracy: {accuracy} %')



              


768it [00:43, 18.54it/s]