In [55]:
import os
import torch
import clip
from PIL import Image

from  torch.utils.data  import Dataset, DataLoader
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from tqdm import tqdm
from collections import OrderedDict

In [16]:
class PCam200(Dataset):
    """Binary image dataset laid out as ``<dataset_dir>/{normal,tumor}/<slide>/<image>``.

    Every image file found under a slide directory of ``normal`` gets label 0
    and every image file under a slide directory of ``tumor`` gets label 1.
    Only file paths are collected at construction time; images are decoded
    lazily in ``__getitem__``.

    Args:
        dataset_dir: Root directory containing the ``normal`` and ``tumor``
            sub-directories.
        transform: Optional callable applied to the decoded RGB PIL image.
            When ``None`` the PIL image is returned unchanged.

    Attributes set by the constructor: ``transform``, ``root_dir``,
    ``normals`` / ``tumors`` (sorted entry names of each class directory),
    ``data`` (image paths) and ``label`` (0/1 per path, parallel to ``data``).
    """

    # File extensions (case-sensitive, as in the original list) accepted as images.
    IMAGE_EXTENSIONS = ('.JPG', '.jpg', '.JPEG', '.png')
    # Entry names that are never scanned as slide directories.
    # NOTE(review): the original code tried to drop this name from a
    # non-existent ``self.dirs`` list; presumably it is an auxiliary folder
    # that may sit next to the slide folders -- confirm against the data layout.
    EXCLUDED_DIRS = ('thumbnail_position_map',)

    def __init__(self, dataset_dir, transform=None):
        self.transform = transform
        self.root_dir = dataset_dir
        self.normals = self._list_slides('normal')
        self.tumors = self._list_slides('tumor')
        self.data = []
        self.label = []
        # Normal samples first, then tumor samples, each in sorted order.
        self._collect('normal', self.normals, 0)
        self._collect('tumor', self.tumors, 1)

    def _list_slides(self, class_name):
        """Return the sorted entry names of ``<root>/<class_name>`` minus excluded ones."""
        entries = sorted(os.listdir(os.path.join(self.root_dir, class_name)))
        return [name for name in entries if name not in self.EXCLUDED_DIRS]

    def _collect(self, class_name, slides, class_label):
        """Append every image path below the given slide directories with ``class_label``."""
        for slide in slides:
            slide_dir = os.path.join(self.root_dir, class_name, slide)
            if not os.path.isdir(slide_dir):
                # Stray files next to the slide directories are ignored.
                continue
            for file_name in sorted(os.listdir(slide_dir)):
                # splitext looks at the LAST dot, so names such as
                # "patch.v2.png" are recognised correctly.
                if os.path.splitext(file_name)[1] in self.IMAGE_EXTENSIONS:
                    self.data.append(os.path.join(slide_dir, file_name))
                    self.label.append(class_label)

    def __len__(self):
        """Number of image files discovered."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where label is an int64 scalar tensor (0 or 1)."""
        with open(self.data[idx], 'rb') as f:
            # convert() forces the pixel data to be read before the file closes.
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(self.label[idx], dtype=torch.int64)


In [63]:
# Load the pretrained CLIP model and its matching image preprocessing
# pipeline, running on the GPU when one is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)


In [38]:
# zero-shot prediction
# Directory with the sample images. The same variable is used both to list
# and to open the files, so the cell no longer depends on a `cur_dir` left
# over from earlier kernel state.
cur_dir = "./images"

# One text prompt per class; index 0 is 'tumor', index 1 is 'normal'.
text_inputs = torch.cat([clip.tokenize(f"a photo of lymph node {c} tissue") for c in ['tumor', 'normal']]).to(device)

# The prompts are identical for every image, so encode and L2-normalise them once.
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# sorted() makes the output order reproducible (os.listdir order is unspecified).
for img in sorted(os.listdir(cur_dir)):
    print(img)
    # Context manager closes the file handle once preprocessing has produced a tensor.
    with Image.open(os.path.join(cur_dir, img)) as pil_image:
        image_input = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarity between the image and each prompt,
        # turned into a probability over the two prompts.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    print("  Tumor : {:.2f}".format( similarity[0][0].item() ))
    print("  Normal: {:.2f}".format( similarity[0][1].item() ))


normal021_0000026.jpg
  Tumor : 0.79
  Normal: 0.21
normal054_0000014.jpg
  Tumor : 0.74
  Normal: 0.26
tumor007_0000044.jpg
  Tumor : 0.69
  Normal: 0.31
tumor016_0000038.jpg
  Tumor : 0.71
  Normal: 0.29


In [49]:
# linear probe
def get_features(dataset_dir, transform=None, batch_size=100):
    """Encode every image of a PCam200 directory with the CLIP image encoder.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        dataset_dir: Root directory passed to ``PCam200``.
        transform: Per-image transform passed to ``PCam200``.
        batch_size: Number of images encoded per forward pass (default 100,
            the value that was previously hard-coded).

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays, in dataset order
        (the DataLoader does not shuffle).
    """
    dataset = PCam200(dataset_dir, transform=transform)
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=batch_size)):
            features = model.encode_image(images.to(device))

            # Offload each batch right away so the accelerator never has to
            # hold the features of the whole dataset at once.
            all_features.append(features.cpu())
            all_labels.append(labels)

    return torch.cat(all_features).numpy(), torch.cat(all_labels).numpy()

# Calculate the image features
train_features, train_labels = get_features("./train", preprocess)
test_features, test_labels = get_features("./test", preprocess)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
# The `np.float` alias was removed from NumPy (it was only ever the builtin
# `float`), so use the builtin directly; the result is unchanged.
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

  0%|                                                                                          | 0/286 [00:00<?, ?it/s]




100%|████████████████████████████████████████████████████████████████████████████████| 286/286 [05:59<00:00,  1.26s/it]
  0%|                                                                                          | 0/177 [00:00<?, ?it/s]




100%|████████████████████████████████████████████████████████████████████████████████| 177/177 [03:41<00:00,  1.25s/it]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


Accuracy = 82.239


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    4.2s finished


In [None]:
# train epochs
epochs = 5
# batch size, reduce batch_size if out of memory error occurs
batch_size = 512

# Linear head transfer learning: the CLIP visual encoder followed by a new
# 2-way linear classifier. The head's input width is the second dimension of
# the encoder's projection matrix.
new_model = torch.nn.Sequential(OrderedDict([
    ('clip', model.visual),
    ('head', torch.nn.Linear(model.visual.proj.size(1), 2, bias=True)),
]))
# initialize bias of added linear layer as zero
torch.nn.init.zeros_(new_model.head.bias.data)

# cast dtypes weight and bias of added linear layer to model dtype
new_model.head.weight.data = new_model.head.weight.data.to(model.dtype)
new_model.head.bias.data = new_model.head.bias.data.to(model.dtype)
new_model.to(device)
# NOTE(review): if `model.dtype` is half precision (presumably the case when
# CLIP is loaded on CUDA -- confirm), plain SGD on fp16 weights can be
# numerically fragile; consider training in float32.

# prepare dataset and dataloader
# NOTE(review): absolute, machine-specific paths; the linear-probe cell uses
# "./train" and "./test" instead -- confirm which location is intended.
train_set = PCam200("K:\\Pcam200_new\\train", transform=preprocess)
test_set = PCam200("K:\\Pcam200_new\\test", transform=preprocess)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size*2, shuffle=False)

# define optimizer and loss function
optimizer = torch.optim.SGD([{"params": new_model.head.parameters(), "lr": 1e-3},
                             {"params": new_model.clip.parameters(), "lr": 1e-4}],  # smaller learning rate for backbone model
                            momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

# training
for epoch in range(epochs):
    # A tqdm wrapper disables itself once it has been exhausted, so a fresh
    # one is created for every epoch to keep the progress bar visible.
    train_iterator = tqdm(train_loader)
    new_model.train()
    for batch in train_iterator:
        img, label = batch[0].to(model.dtype).to(device), batch[1].to(device)

        out = new_model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()      # calculate gradients
        optimizer.step()     # update parameters

    # testing
    total = len(test_set)
    correct = 0.0
    test_iterator = tqdm(test_loader)
    new_model.eval()
    with torch.no_grad():
        for batch in test_iterator:
            img, label = batch[0].to(model.dtype).to(device), batch[1].to(device)

            out = new_model(img)

            # Predicted class is the index of the largest logit.
            _, predicted = torch.max(out, 1)
            correct += (predicted == label).sum().item()

    # Evaluate accuracy on test set
    accuracy = (correct/total) * 100.
    # `epoch` is zero-based; report the number of epochs actually completed.
    print(f"{epoch + 1} epochs of fine-tuning yields,")
    print(f"Accuracy = {accuracy:.3f}")