# model

In [2]:
import torch
import torchvision
from torch import nn

from lightly.models import utils

from model import MAE
from tqdm.auto import tqdm, trange
import os
model_path = 'D:\google drive\MAE Bs PJ\state_dict\epoch 3-4200 weight.pt'

In [3]:
vit = torchvision.models.vit_b_16(weights=None)
pretrained_model = MAE(vit, 64, 80)
pretrained_model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [12]:
class ViT(nn.Module):
    def __init__(self, pretrained_model, num_class):
        super(ViT, self).__init__()
        self.backbone = nn.Sequential(
                    pretrained_model.backbone,
                    nn.Linear(768, num_class),
                    nn.Softmax(dim=1)
                )
        self.out_dim = 80

    def forward(self, images):
        batch_size = images.shape[0]
        seq_length = images.shape[1]
        images = images.reshape((batch_size, seq_length, self.out_dim))
        x = self.backbone(images)
        return x

In [13]:
learning_rate = 1e-5
batch_size = 16
Epochs = 20
num_classes = len(os.listdir(r'D:\MaE fine tune\DataSet\train section'))

model = ViT(pretrained_model, num_class = num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model = model.to('cuda')
model.train()

ViT(
  (backbone): Sequential(
    (0): MAEBackbone(
      (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (encoder): MAEEncoder(
        (dropout): Dropout(p=0.0, inplace=False)
        (layers): Sequential(
          (encoder_layer_0): EncoderBlock(
            (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (self_attention): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (dropout): Dropout(p=0.0, inplace=False)
            (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (mlp): MLPBlock(
              (0): Linear(in_features=768, out_features=3072, bias=True)
              (1): GELU(approximate='none')
              (2): Dropout(p=0.0, inplace=False)
              (3): Linear(in_features=3072, out_features=768, bias=True)
              (4): Dropout(p=0.0, inplace=False)
            )
          )
          (

# data

In [5]:
CHECKPOINT_PATH = 'checkpoint.tar'

checkpoint = torch.load(CHECKPOINT_PATH)
step = checkpoint['step']
mean = checkpoint['mean']
std = checkpoint['std']
total_loss = 0

In [6]:
from utils import normalize_data, Node_Dataset, get_time
from torch.utils.data import DataLoader, Dataset

data_folder = r'DataSet\train section'
test_folder = r'DataSet\test section'

dataset = Node_Dataset(data_folder, mean, std)
trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

test_dataset = Node_Dataset(test_folder, mean, std)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# WandB

In [7]:
import wandb


wandb.init(
    project='MaE fine tune',
    name=f'batch_size: {batch_size}-2',
    config={
        "learning_rate": learning_rate,
        "architecture": "ViT",
        "dataset": "customer",
        "epochs": Epochs,
        "mean": mean,
        "std": std})
config = wandb.config

wandb.watch(model)
epoch = 0
step = 0
total_loss = 0
total_accuracy = 0
model_savepath = f'D:\MaE fine tune\state_dict'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


# training

In [8]:
from utils import video_show, train_step, Evaluate, log

In [9]:
for epoch in range(Epochs):
    for i, (batch, target) in tqdm(enumerate(trainloader), total=len(trainloader)):

        loss, accuracy = train_step(model,
                    criterion,
                    batch,
                    target,
                    optimizer,
                    num_classes=num_classes)
        total_loss += loss
        total_accuracy += accuracy
        with torch.no_grad():
            evaluative_loss, evaluative_accuracy = Evaluate(model, criterion, testloader, num_classes)

        wandb.log({"accuracy": accuracy,"evaluative_accuracy": evaluative_accuracy})
    total_accuracy = total_accuracy//(i+1)
    torch.save(model.state_dict(),os.path.join(model_savepath, f'2epoch-{epoch+1}weight.pt'))
    print(f'train_loss": {total_loss},"accuracy": {total_accuracy}%|"test_loss": {evaluative_loss},"evaluative_accuracy": {evaluative_accuracy}%')
    total_loss = 0
    total_accuracy = 0
    print(f'epoch: {epoch:>02}, time: {get_time()}')

  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 48.56214904785156,"accuracy": 10.0%|"test_loss": 2.661360263824463,"evaluative_accuracy": 13.33%
epoch: 00, time: 8:35:19


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 47.526123046875,"accuracy": 17.0%|"test_loss": 2.546891927719116,"evaluative_accuracy": 33.33%
epoch: 01, time: 8:35:40


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 44.739131927490234,"accuracy": 37.0%|"test_loss": 2.511406183242798,"evaluative_accuracy": 26.67%
epoch: 02, time: 8:36:1


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 43.4141960144043,"accuracy": 45.0%|"test_loss": 2.4380621910095215,"evaluative_accuracy": 40.0%
epoch: 03, time: 8:36:22


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 42.34962844848633,"accuracy": 49.0%|"test_loss": 2.4108965396881104,"evaluative_accuracy": 40.0%
epoch: 04, time: 8:36:43


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 42.19560241699219,"accuracy": 51.0%|"test_loss": 2.323470115661621,"evaluative_accuracy": 53.33%
epoch: 05, time: 8:37:4


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 40.883453369140625,"accuracy": 58.0%|"test_loss": 2.3269786834716797,"evaluative_accuracy": 53.33%
epoch: 06, time: 8:37:39


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 39.62763214111328,"accuracy": 63.0%|"test_loss": 2.2648561000823975,"evaluative_accuracy": 60.0%
epoch: 07, time: 8:38:0


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 39.05430603027344,"accuracy": 69.0%|"test_loss": 2.172774314880371,"evaluative_accuracy": 66.67%
epoch: 08, time: 8:38:21


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 38.248050689697266,"accuracy": 71.0%|"test_loss": 2.1071667671203613,"evaluative_accuracy": 73.33%
epoch: 09, time: 8:38:42


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 38.38497543334961,"accuracy": 70.0%|"test_loss": 2.1434545516967773,"evaluative_accuracy": 73.33%
epoch: 10, time: 8:39:4


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 38.09982681274414,"accuracy": 71.0%|"test_loss": 2.091933012008667,"evaluative_accuracy": 73.33%
epoch: 11, time: 8:39:25


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 37.46954345703125,"accuracy": 76.0%|"test_loss": 2.106869697570801,"evaluative_accuracy": 73.33%
epoch: 12, time: 8:39:46


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 38.90101623535156,"accuracy": 67.0%|"test_loss": 2.1690118312835693,"evaluative_accuracy": 73.33%
epoch: 13, time: 8:40:7


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 37.56167984008789,"accuracy": 74.0%|"test_loss": 2.0912442207336426,"evaluative_accuracy": 73.33%
epoch: 14, time: 8:40:28


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 37.34629821777344,"accuracy": 76.0%|"test_loss": 2.077385902404785,"evaluative_accuracy": 80.0%
epoch: 15, time: 8:40:50


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 37.54411315917969,"accuracy": 73.0%|"test_loss": 2.113452196121216,"evaluative_accuracy": 73.33%
epoch: 16, time: 8:41:26


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 37.57820510864258,"accuracy": 73.0%|"test_loss": 2.046360492706299,"evaluative_accuracy": 80.0%
epoch: 17, time: 8:42:23


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 36.50830078125,"accuracy": 79.0%|"test_loss": 2.0781543254852295,"evaluative_accuracy": 73.33%
epoch: 18, time: 8:42:52


  0%|          | 0/18 [00:00<?, ?it/s]

train_loss": 36.48641586303711,"accuracy": 79.0%|"test_loss": 2.083428382873535,"evaluative_accuracy": 73.33%
epoch: 19, time: 8:43:13
