In [8]:
!pip install tdqm

Collecting tdqm
  Downloading tdqm-0.0.1.tar.gz (1.4 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting tqdm
  Using cached tqdm-4.65.0-py3-none-any.whl (77 kB)
Building wheels for collected packages: tdqm
  Building wheel for tdqm (setup.py) ... [?25ldone
[?25h  Created wheel for tdqm: filename=tdqm-0.0.1-py3-none-any.whl size=1323 sha256=60a743020659d02088351b7dd29fb601d5dae71a2f65de35aabfbbe1f02502b9
  Stored in directory: /Users/adnanoomerjee/Library/Caches/pip/wheels/ff/ec/31/6c16e9c6cf6c186a1b0e48a63a0067dee1be20d92f073b861b
Successfully built tdqm
Installing collected packages: tqdm, tdqm
Successfully installed tdqm-0.0.1 tqdm-4.65.0


In [27]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torch.utils.data import TensorDataset, ConcatDataset, random_split

# Define the root directory where the dataset should be stored
root = ''

# Side length (pixels) that every image and mask is resized to
IMAGE_SIZE = (64, 64)

# Load the dataset using the OxfordIIITPet class with download=True
# do not apply transforms here as they affect the loading of the targets
train_val_data = datasets.OxfordIIITPet(root=root, split='trainval',
                                        target_types=['category', 'segmentation'], download=True)
test_data = datasets.OxfordIIITPet(root=root, split='test',
                                   target_types=['category', 'segmentation'], download=True)

# Pool both official splits; new train/val/test splits are drawn below
raw_dataset = ConcatDataset([train_val_data, test_data])

# Define the transform to apply to images
img_transform = transforms.Compose([transforms.Resize(IMAGE_SIZE),  # resize the images to 64x64 pixels
                                    transforms.ToTensor()  # convert the images to tensors, apply scaling (from 0-255 to 0-1)
                                   ])

# Define the transform to apply to masks
# Masks hold class ids, so they must be resized with nearest-neighbour
# interpolation: a smoothing filter would blend neighbouring ids into values
# that do not correspond to any class.
mask_transform = transforms.Compose([transforms.Resize(IMAGE_SIZE,
                                                       interpolation=transforms.InterpolationMode.NEAREST),
                                     transforms.PILToTensor(),       # convert to tensor, do not apply scaling
                                     transforms.Lambda(lambda x: x - 1)  # remove 1 since pixel classes are 1-indexed
                                    ])

# loop through all images, apply transforms and store in lists
# cannot directly apply transforms due to (class, mask) tuple in original dataset
all_img = []
all_mask = []
all_label = []

for img, (class_label, mask) in raw_dataset:
    all_img.append(img_transform(img))
    all_mask.append(mask_transform(mask))
    all_label.append(class_label)

# create new dataset of (image, mask, class label) tensors
dataset = TensorDataset(torch.stack(all_img), torch.stack(all_mask), torch.tensor(all_label))

# create train, val, test splits (70%,10%,20%)
len_train = int(0.7 * len(dataset))
len_val = int(0.1 * len(dataset))
len_test = len(dataset) - len_train - len_val  # remainder, so the three lengths always sum to len(dataset)
trainset, valset, testset = random_split(dataset, [len_train, len_val, len_test],
                                         generator=torch.Generator().manual_seed(42))  # fixed seed for reproducible splits


In [34]:
# create unlabelled set
# The transforms are passed to the constructor on purpose: torchvision builds
# the combined `transforms` callable used by `__getitem__` at construction
# time, so assigning `unlabelled_set.transform` afterwards is not picked up
# and samples would be returned as untransformed PIL images.
unlabelled_set = datasets.OxfordIIITPet(root=root, split='trainval', target_types=['segmentation'],
                                        transform=img_transform, target_transform=mask_transform,
                                        download=True)
print(len(train_val_data))
print(len(test_data))
print(len(unlabelled_set))

3680
3669
3680


In [29]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch import nn, optim
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.models as models


In [30]:
# Pick the best backend that is actually available instead of assuming Apple
# MPS: CUDA first, then MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [31]:
# Report whether a CUDA device is visible to this runtime.
cuda_visible = torch.cuda.is_available()
print('Colab is running on GPU!' if cuda_visible else 'Colab is running on CPU')

Colab is running on CPU


In [32]:
class PretrainedUNet(nn.Module):
    """Segmentation model: a pretrained ResNet-18 backbone followed by a transposed-conv decoder.

    The decoder consists of four stride-2 transposed convolutions (channel
    widths 512 -> 256 -> 128 -> 64 -> num_classes), each doubling the spatial
    size, with ReLU between them and no activation after the last one, so the
    forward pass returns raw, unbounded scores.

    NOTE(review): the backbone is expected to shrink the input more than the
    decoder enlarges it, so the output is presumably smaller than the input —
    confirm against the torchvision ResNet-18 definition.

    Args:
        num_classes: number of output channels of the final layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.encoder = models.resnet18(pretrained=True)

        # Build the decoder from its channel schedule; every step but the last
        # is followed by an in-place ReLU.
        widths = [512, 256, 128, 64]
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.ConvTranspose2d(widths[-1], num_classes, kernel_size=2, stride=2))
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run the backbone's stem and four residual stages, then the decoder."""
        backbone = self.encoder
        # Only the convolutional part of the backbone is used; its pooling and
        # classification head are skipped.
        pipeline = (
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
        )
        for stage in pipeline:
            x = stage(x)
        return self.decoder(x)

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - mean(dice)`` over every (sample, channel) pair.

    ``dice = (2 * sum(pred * true) + smooth) / (sum(pred) + sum(true) + smooth)``
    with the sums taken over the two spatial dimensions.

    Args:
        smooth: additive smoothing term; keeps the ratio defined when both the
            prediction and the target are empty.
        from_logits: when True, a sigmoid is applied to ``y_pred`` first. The
            default (False) uses ``y_pred`` as given, so callers passing raw
            scores should either enable this or apply the sigmoid themselves;
            the Dice ratio is only bounded for predictions in [0, 1].
    """

    def __init__(self, smooth=1.0, from_logits=False):
        super().__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return the scalar Dice loss.

        Args:
            y_pred: predictions of shape (batch, channels, height, width).
            y_true: target masks of shape (batch, channels, H, W); the spatial
                size may differ from ``y_pred`` and is resampled to match it.
        """
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)

        # Work in the prediction's dtype so that integer masks of any dtype
        # can be resized and multiplied with the predictions.
        y_true = y_true.to(dtype=y_pred.dtype)
        target_size = y_pred.shape[2:]
        if y_true.shape[2:] != target_size:
            # Nearest-neighbour keeps the mask values intact (no blending).
            y_true = F.interpolate(y_true, size=target_size, mode='nearest')

        intersection = (y_pred * y_true).sum(dim=[2, 3])
        union = y_pred.sum(dim=[2, 3]) + y_true.sum(dim=[2, 3])
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
    
def _predict_probs(model, images, from_logits):
    """Run `model` on `images`; squash raw scores into (0, 1) with a sigmoid when `from_logits` is set."""
    outputs = model(images)
    return torch.sigmoid(outputs) if from_logits else outputs


def _pseudo_label_pool(model, unlabelled_set, pool, batch_size, device, from_logits,
                       pred_threshold, confidence_threshold, mask_size, mask_dtype):
    """Predict masks for the pool and return (images, masks, indices) of the confidently predicted samples."""
    kept_images, kept_masks, accepted = [], [], []
    model.eval()
    with torch.no_grad():
        for start in tqdm(range(0, len(pool), batch_size)):
            batch_idx = pool[start:start + batch_size]
            # Only the image (first field) of each unlabelled item is used.
            images = torch.stack([unlabelled_set[i][0] for i in batch_idx])
            probs = _predict_probs(model, images.to(device), from_logits)

            # Per-image confidence: mean distance of each pixel from "undecided",
            # expressed as the probability of the more likely class.
            confidence = torch.max(probs, 1 - probs).flatten(1).mean(dim=1)
            keep = (confidence >= confidence_threshold).cpu()
            if not keep.any():
                continue

            # Binarise, then bring the mask to the layout of the labelled masks
            # so both kinds of samples can be batched together.
            hard = (probs > pred_threshold).float()
            hard = F.interpolate(hard, size=mask_size, mode='nearest').cpu()
            kept_images.append(images[keep])
            kept_masks.append(hard[keep].to(mask_dtype))
            accepted.extend(i for i, flag in zip(batch_idx, keep.tolist()) if flag)
    return kept_images, kept_masks, accepted


def train(trainset, valset, unlabelled_set, model, batch_size=16, epochs=100, pred_threshold=0.5, loss_threshold=0.001, patience=10,
          from_logits=True, confidence_threshold=0.9):
    """Train a segmentation model with Dice loss, early stopping and pseudo-labelling.

    Every epoch runs one pass over the training data and one over the
    validation data. Whenever the validation loss improves by more than
    `loss_threshold`, the weights are saved to 'model.pt'. After each epoch the
    model labels the remaining unlabelled images; images it is confident about
    are moved into the training data (with the predicted mask) for the
    following epochs. The datasets passed in are not modified.

    Args:
        trainset: dataset of (image, mask, label) items; labels are not used.
        valset: dataset of (image, mask, label) items.
        unlabelled_set: indexable dataset whose items have an image tensor as
            their first element; any further fields are ignored.
        model: network mapping an image batch to a (batch, channels, h, w) map.
        batch_size: batch size for all passes.
        epochs: maximum number of epochs.
        pred_threshold: probability above which a pixel is labelled 1 when
            building pseudo-masks.
        loss_threshold: minimum validation-loss improvement that counts.
        patience: number of consecutive non-improving epochs before stopping.
        from_logits: if True (default) a sigmoid is applied to the model
            output before the loss and before thresholding; set to False for
            models that already output probabilities.
        confidence_threshold: minimum per-image confidence (mean over pixels of
            max(p, 1 - p)) for an image to be pseudo-labelled.

    Returns:
        The model in its state after the last executed epoch (the best
        weights are the ones stored in 'model.pt').
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    pin_memory = device.type == "cuda"  # pinned host memory only helps copies to a GPU

    def make_loader(data, shuffle=False):
        return DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=pin_memory)

    trainloader = make_loader(trainset, shuffle=True)
    valloader = make_loader(valset)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = DiceLoss()

    # A labelled sample serves as the template for pseudo-labelled ones so that
    # both can be collated into the same batch.
    _, ref_mask, ref_label = trainset[0]
    mask_size, mask_dtype = tuple(ref_mask.shape[-2:]), ref_mask.dtype
    label_dtype = torch.as_tensor(ref_label).dtype

    base_trainset = trainset
    pool = list(range(len(unlabelled_set)))  # indices that are still unlabelled
    pseudo_images, pseudo_masks = [], []

    best_val_loss = float('inf')
    early_stop_counter = 0
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}')
        model.train()
        train_loss = 0.0
        for images, masks, _ in tqdm(trainloader):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(_predict_probs(model, images, from_logits), masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
        train_loss /= len(trainloader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks, _ in tqdm(valloader):
                images, masks = images.to(device), masks.to(device)
                loss = criterion(_predict_probs(model, images, from_logits), masks)
                val_loss += loss.item() * images.size(0)
            val_loss /= len(valset)

        if val_loss < best_val_loss - loss_threshold:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), 'model.pt')
            print(f'Saved model at epoch {epoch+1}')
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f'Early stopping after {epoch+1} epochs')
                return model

        print(f'Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}')

        if not pool:
            continue

        new_images, new_masks, accepted = _pseudo_label_pool(
            model, unlabelled_set, pool, batch_size, device, from_logits,
            pred_threshold, confidence_threshold, mask_size, mask_dtype)
        if accepted:
            pseudo_images.extend(new_images)
            pseudo_masks.extend(new_masks)
            all_images = torch.cat(pseudo_images)
            all_masks = torch.cat(pseudo_masks)
            # Placeholder category: labels are not used by this training loop.
            placeholder = torch.ones(all_images.size(0), dtype=label_dtype)
            trainset = base_trainset + TensorDataset(all_images, all_masks, placeholder)
            # The loader must be rebuilt, otherwise the new samples are never visited.
            trainloader = make_loader(trainset, shuffle=True)
            accepted = set(accepted)
            pool = [i for i in pool if i not in accepted]
            print(f'Added {len(accepted)} pseudo-labelled images')
        else:
            print('No new labelled images found')

    return model


In [33]:
# Launch training of the ResNet-18 based model on the splits created above;
# the trained model returned by `train` is kept in `test`.
test = train(
    trainset,
    valset,
    unlabelled_set,
    model=PretrainedUNet(),
    batch_size=16,
    epochs=100,
    pred_threshold=0.5,
    loss_threshold=0.001,
    patience=10,
)
# test = train(trainset, valset, unlabelled_set, model=Res_U_Net(), batch_size=16, epochs=100, pred_threshold=0.5, loss_threshold=0.001, patience=10)



Epoch 1


 24%|██▍       | 77/322 [00:33<01:46,  2.31it/s]


KeyboardInterrupt: 