In [6]:
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import time
import PIL

In [4]:

import torch
import RETFound_MAE.models_vit as models_vit
from RETFound_MAE.util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# Build the classifier: look up the 'vit_large_patch16' factory exposed by
# RETFound's models_vit module and call it with a 2-class head,
# drop_path_rate=0.2 and global pooling switched on.
build_vit_large = vars(models_vit)['vit_large_patch16']
model = build_vit_large(num_classes=2, drop_path_rate=0.2, global_pool=True)

# Read the pre-trained RETFound checkpoint onto the CPU and take its weight dict.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted checkpoint.
checkpoint = torch.load('./RETFound_MAE/RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# Discard classifier-head tensors from the checkpoint when their shape does not
# match the head of the freshly built model; they could not be loaded anyway.
for head_key in ('head.weight', 'head.bias'):
    if head_key not in checkpoint_model:
        continue
    if checkpoint_model[head_key].shape == state_dict[head_key].shape:
        continue
    print(f"Removing key {head_key} from pretrained checkpoint")
    del checkpoint_model[head_key]

# Hand the checkpoint to RETFound's helper so its position embedding is
# adapted to this model before loading.
interpolate_pos_embed(model, checkpoint_model)

# Non-strict load: keys absent from the checkpoint are reported, not fatal.
msg = model.load_state_dict(checkpoint_model, strict=False)

# The only parameters allowed to be missing are the classifier head and fc_norm;
# anything else means the checkpoint does not fit this architecture.
expected_missing = {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
assert set(msg.missing_keys) == expected_missing

# The head was not loaded from the checkpoint, so initialise its weight by hand.
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % (model,))

Model = VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (n

In [5]:
#load images in RETFound_dataset and transform them
import torch
from torchvision import datasets, transforms

# Preprocessing pipeline applied to every image before it reaches the model:
#   1. resize to 224x224,
#   2. convert to a PyTorch tensor,
#   3. normalise each of the three channels with the mean/std below.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)


In [8]:
# Smoke test: push one image through the preprocessing pipeline and the model.
# Import the submodule explicitly; a bare `import PIL` is not guaranteed to
# make `PIL.Image` available unless some other library has already loaded it.
from PIL import Image

example_image_path = 'databases/ODIR-5K/Testing Images/937_left.jpg'
# Open inside a context manager so the file handle is released, and force
# three channels because `transform` normalises with 3-element mean/std
# (a grayscale or RGBA file would otherwise fail in Normalize).
with Image.open(example_image_path) as raw_image:
    example_image = transform(raw_image.convert('RGB'))
# add the batch dimension the model expects in front of the image dimensions
example_image = example_image.unsqueeze(0)
print(example_image.shape)
# Run the model in inference mode without tracking gradients.
# NOTE(review): the recorded run of this cell raised a RuntimeError in this
# forward pass (a LayerNorm over 1024 features received a size-[1] input).
# That looks like a mismatch between the installed timm version and the one
# RETFound's models_vit was written for -- confirm before relying on `output`.
model.eval()
with torch.no_grad():
    output = model(example_image)

torch.Size([1, 3, 224, 224])


RuntimeError: Given normalized_shape=[1024], expected input with shape [*, 1024], but got input of size[1]

In [21]:

# Training data: ImageFolder reads the images under the directory below and
# runs each one through `transform`.
train_set = datasets.ImageFolder(root='datasets/2023-12-27_15-42-07', transform=transform)
# Serve the dataset in shuffled mini-batches of two samples.
loader_options = dict(batch_size=2, shuffle=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_set, **loader_options)


In [22]:
# Swap out the classification head so its output size matches the number of
# classes found in the training set; the input width of the old head is kept.
n_classes = len(train_set.classes)
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=n_classes, bias=True)
# The model-loading cell initialises the head weight with a truncated normal
# (std=2e-5); replacing the layer throws that away, so apply the same
# initialisation to the new head to keep both cells consistent.
torch.nn.init.trunc_normal_(model.head.weight, std=2e-5)


In [23]:
from tqdm import tqdm
import time

import os

# Fine-tune the model on `train_dataloader` and save the resulting weights.
# Training is pinned to the CPU; the commented-out expression selects CUDA when available.
device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(f'You are training on:{device}')
# use cross entropy loss
criterion = torch.nn.CrossEntropyLoss()
# use adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 1

# The single-image inference cell calls model.eval(); without switching back,
# dropout / drop-path layers would stay disabled for the whole fine-tuning run.
model.train()

# train the model
for epoch in range(epochs):
    # Running counters for this epoch, so the progress bar reports accuracy
    # over everything seen so far rather than over one two-image batch.
    total = 0
    correct = 0
    progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}')
    for images, labels in progress_bar:
        # move the batch to the training device
        images, labels = images.to(device), labels.to(device)
        # zero gradients left over from the previous step
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # calculate loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
        # Predicted class = index of the largest logit; detach() keeps this
        # bookkeeping out of the autograd graph.
        predicted = outputs.detach().argmax(dim=1)

        # Accumulate the number of labels and of correct predictions
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Show the current batch loss and the running accuracy
        progress_bar.set_postfix({'Loss': '{:.4f}'.format(loss.item()), 'Accuracy': '{:.2f}%'.format(100 * correct / total)})

# save model, change later to save best model
# Create the target directory first so a long training run is not lost to a
# missing folder; the epoch count in the file name follows `epochs`.
weights_dir = 'RETFound_MAE/weights'
os.makedirs(weights_dir, exist_ok=True)
torch.save(model.state_dict(), f'{weights_dir}/{time.strftime("%Y%m%d_%H%M")}_RETFound_cfp_fine_tuned_ep{epochs}_weights.pth')


You are training on:cpu


Epoch 1/1:   1%|          | 10/1321 [00:35<1:18:28,  3.59s/it, Loss=1.6822, Accuracy=0.00%]


KeyboardInterrupt: 