In [1]:
from base_vit import ViT
import torch
from lora import LoRA_ViT
from seg_vit import SegWrapForViT

In [2]:
img = torch.randn(2, 3, 384, 384)

In [3]:
model = ViT('B_16_imagenet1k')
model.load_state_dict(torch.load('B_16_imagenet1k.pth'))
lora_model = LoRA_ViT(model, r=4)

In [4]:
seg_vit = SegWrapForViT(vit_model=lora_model, image_size=384,
                        patches=16, dim=768, n_classes=10)



In [5]:
mask = seg_vit(img)
print(mask.shape)

torch.Size([2, 10, 384, 384])


In [6]:
num_params = sum(p.numel() for p in seg_vit.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params/2**20:.3f}")

Number of trainable parameters: 6.459


In [7]:
lora_model.save_lora_parameters('feb27.safetensors')

In [8]:
lora_model.load_lora_parameters('feb27.safetensors')

In [None]:
# loop over the dataset multiple times
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print('Loss: {}'.format(running_loss)

print('Finished Training')

In [7]:
import timm 
import torch

In [12]:
all_vit_models = timm.list_models('*vit*', pretrained=True)
all_vit_models

['convit_base',
 'convit_small',
 'convit_tiny',
 'crossvit_9_240',
 'crossvit_9_dagger_240',
 'crossvit_15_240',
 'crossvit_15_dagger_240',
 'crossvit_15_dagger_408',
 'crossvit_18_240',
 'crossvit_18_dagger_240',
 'crossvit_18_dagger_408',
 'crossvit_base_240',
 'crossvit_small_240',
 'crossvit_tiny_240',
 'gcvit_base',
 'gcvit_small',
 'gcvit_tiny',
 'gcvit_xtiny',
 'gcvit_xxtiny',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_384',
 'maxvit_nano_rw_256',
 'maxvit_rmlp_nano_rw_256',
 'maxvit_rmlp_pico_rw_256',
 'maxvit_rmlp_small_rw_224',
 'maxvit_rmlp_tiny_rw_256',
 'maxvit_tiny_rw_224',
 'maxxvit_rmlp_nano_rw_256',
 'maxxvit_rmlp_small_rw_256',
 'mobilevit_s',
 'mobilevit_xs',
 'mobilevit_xxs',
 'mobilevitv2_050',
 'mobilevitv2_075',
 'mobilevitv2_100',
 'mobilevitv2_125',
 'mobilevitv2_150',
 'mobilevitv2_150_384_in22ft1k',
 'mobilevitv2_150_in22ft1k',
 'mobilevitv2_175',
 'mobilevitv2_175_384_in22ft1k',
 'mobilevitv2_175_in22ft1k',
 'mobilevitv2_200',
 'mobile