In [1]:
import os
import torch
import yaml
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from data.multi_view_data_injector import MultiViewDataInjector
from data.transforms import get_simclr_data_transforms
from models.mlp_head import MLPHead
from models.resnet_base_network import ResNet18
from trainer import BYOLTrainer
from PIL import Image
from pathlib import Path

In [2]:
# Seed torch's global RNG so weight init, augmentations and shuffling are repeatable.
torch.manual_seed(0)

# Read the experiment configuration. The context manager closes the file handle
# as soon as parsing finishes instead of leaving it open for the kernel's lifetime.
with open("./config/config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
class ImagesDataset(Dataset):
    """Unlabeled PNG image dataset used for self-supervised (BYOL) training.

    Recursively collects every ``.png`` file (extension matched
    case-insensitively) under ``root``. The collected paths are sorted so that
    dataset indices are identical across machines and runs. ``Path.glob``
    yields entries in filesystem order, which would otherwise make
    ``torch.manual_seed`` insufficient for reproducible shuffling.

    A per-image label is stored in ``self.labels`` as the last character of
    the image's parent directory path. Presumably the folders are named after
    KL grades such as ``.../3/`` (TODO confirm the on-disk layout). The label
    is not returned by ``__getitem__``.
    """

    def __init__(self, transform=None, root='kneeKL224/'):
        """
        Args:
            transform (callable, optional): applied to each loaded PIL image;
                its return value is what ``__getitem__`` yields.
            root (str or Path): directory searched recursively for PNG files.
                Defaults to ``'kneeKL224/'``, relative to the working directory.
        """
        self.paths = []
        self.labels = []
        for path in sorted(Path(root).glob('**/*')):
            # Skip directories (even ones whose names end in ".png") and
            # files with any other extension.
            if path.is_file() and path.suffix.lower() == '.png':
                self.paths.append(path)
                # Same value as os.path.split(path)[0][-1]: the last character
                # of the parent directory path.
                self.labels.append(str(path.parent)[-1])
        self.transform = transform

    def __len__(self):
        """Number of PNG files found under ``root``."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index`` and return it, transformed if a transform was set.

        ``Image.open`` is lazy and keeps the file open until the image object is
        garbage-collected. Opening it in a context manager and copying the pixels
        out releases the handle immediately, which matters when DataLoader
        workers touch thousands of files.
        """
        with Image.open(self.paths[index]) as handle:
            img = handle.copy()

        if self.transform:
            img = self.transform(img)

        return img

In [4]:
# SimCLR-style augmentation pipeline built from the config. The same pipeline is
# listed twice so that (presumably) MultiViewDataInjector produces two
# independently augmented views of each image, as BYOL requires.
data_transform = get_simclr_data_transforms(**config['data_transforms'])
two_view_transform = MultiViewDataInjector([data_transform] * 2)
train_dataset = ImagesDataset(transform=two_view_transform)

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training with: {device}")

# online network
online_network = ResNet18(**config['network']).to(device)
pretrained_folder = config['network']['fine_tune_from']

# Optionally warm-start the online network from an earlier run's checkpoint.
# Only 'online_network_state_dict' is restored here; the predictor and optimizer
# below are always constructed fresh.
if pretrained_folder:
    try:
        checkpoints_folder = os.path.join('./runs', pretrained_folder, 'checkpoints')

        # map_location remaps stored tensors onto the current device, so a
        # GPU-saved checkpoint can still be loaded on a CPU-only machine.
        load_params = torch.load(os.path.join(checkpoints_folder, 'model.pth'),
                                 map_location=torch.device(device))

        online_network.load_state_dict(load_params['online_network_state_dict'])

        print(f"Retraining from checkpoint folder {pretrained_folder}")

    except FileNotFoundError:
        print("Pre-trained weights not found. Training from scratch.")

# predictor network: its input width is the output width of the online network's
# projection head. `projetion` (sic) must match the attribute name defined on ResNet18.
predictor = MLPHead(in_channels=online_network.projetion.net[-1].out_features,
                    **config['network']['projection_head']).to(device)

# target encoder: same architecture as the online network. Presumably
# BYOLTrainer initializes/updates it from the online weights (TODO confirm).
target_network = ResNet18(**config['network']).to(device)

# Only the online network and predictor are optimized by SGD; the target network
# is deliberately excluded from the optimizer.
optimizer = torch.optim.SGD(list(online_network.parameters()) + list(predictor.parameters()),
                            **config['optimizer']['params'])

trainer = BYOLTrainer(online_network=online_network,
                      target_network=target_network,
                      optimizer=optimizer,
                      predictor=predictor,
                      device=device,
                      **config['trainer'])

trainer.train(train_dataset)

Training with: cuda
Retraining from checkpoint folder Aug13_14-14-58_DESKTOP-SSCRTA9
End of epoch 0
End of epoch 1
End of epoch 2
End of epoch 3
End of epoch 4
End of epoch 5
End of epoch 6
End of epoch 7
End of epoch 8
End of epoch 9
End of epoch 10
End of epoch 11
End of epoch 12
End of epoch 13
End of epoch 14
End of epoch 15
End of epoch 16
End of epoch 17
End of epoch 18
End of epoch 19
End of epoch 20
End of epoch 21
End of epoch 22
End of epoch 23
End of epoch 24
End of epoch 25
End of epoch 26
End of epoch 27
End of epoch 28
End of epoch 29
End of epoch 30
End of epoch 31
End of epoch 32
End of epoch 33
End of epoch 34
End of epoch 35
End of epoch 36
End of epoch 37
End of epoch 38
End of epoch 39
End of epoch 40
End of epoch 41
End of epoch 42
End of epoch 43
End of epoch 44
End of epoch 45
End of epoch 46
End of epoch 47
End of epoch 48
End of epoch 49
End of epoch 50
End of epoch 51
End of epoch 52
End of epoch 53
End of epoch 54
End of epoch 55
End of epoch 56
End of epoch 