<a href="https://colab.research.google.com/github/jdasam/ant6040-2022/blob/main/assignment_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment 1. Maps of Images
- The goal of this assignment is to understand the embeddings of a deep-learning-based image classification model
  - By projecting the image embeddings into a 2-D space, you can see how the model clusters the images

- If you run this notebook on Colab, it is highly recommended to use a GPU to calculate the embeddings faster

.
- Evaluation Criteria
  - The assignment expects you to achieve at least one of these two goals:
    - a. Discover and explain what kinds of visual characteristics the model uses as important features
      - Your explanation can be just your own hypothesis. But please try to find an example or evidence to support your hypothesis.
      - Note that the pre-trained model was trained with ImageNet 1K
        - [List of Labels in ImageNet 1k](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)
        - [Sample Images](https://github.com/EliSchwartz/imagenet-sample-images)
    - b. Find interesting examples of how the images are distributed in the embedding space
      - Explain why those examples were interesting to you.

.

- Select your own images to complete this assignment
  - Selecting an interesting image set is an important part of the assignment
  - You can download an existing dataset or crawl images from the web using your own keywords
  - Or you can manually import dataset of your selection

.



## Preparation: Install and import libraries
- Running the cells below will automatically install and import libraries you need

In [None]:
# Install the third-party packages used below that Colab does not ship with:
#   - umap-learn: UMAP dimensionality reduction (`import umap`)
#   - jmd_imagescraper: provides `duckduckgo_search` (installed from a git clone)
#   - RasterFairy: used via `transformPointCloud2D` to snap 2-D points onto a grid
# `%cd` moves into each clone so `pip install .` builds it, then returns to the parent.
!pip install umap-learn # install umap
# install jmd_imagescraper
!git clone https://github.com/jdasam/jmd_imagescraper.git
%cd jmd_imagescraper
!pip install .
%cd ..
!git clone https://github.com/Quasimondo/RasterFairy.git
%cd RasterFairy
!pip install .
%cd ..

In [None]:
'''
You don't have to change codes in this cell
'''

from jmd_imagescraper.core import duckduckgo_search
from pathlib import Path
import random
import shutil
import os
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from tqdm.auto import tqdm
import umap
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import rasterfairy
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import numpy as np

def get_image_files(path_dir, file_types=('jpg',)):
  """Return the image files found recursively under ``path_dir``, sorted.

  Args:
    path_dir: directory to search, as a ``str`` or ``pathlib.Path``.
    file_types: file extensions (without the leading dot) to collect.
      Defaults to ``('jpg',)``, i.e. only ``*.jpg`` files.

  Returns:
    A sorted list of ``Path`` objects. ``Path.rglob`` yields files in
    file-system order, so sorting makes index-based selection (such as the
    seeded train/test split) reproducible across machines.
  """
  path_dir = Path(path_dir)  # Path() accepts both str and Path
  return sorted(fn for ext in file_types for fn in path_dir.rglob(f'*.{ext}'))

## 0. Collect your images
- There are two options:
  - Option A: Use ArtBench (a dataset used during the class)
    - Or you can download another image dataset from the web
  - Option B: Download images using web crawler. You can select search keywords of your own.

### Option A. Use ArtBench (or other dataset)
- You can use the ArtBench dataset, which consists of 60,000 paintings in 10 different artistic styles.
- If you want to use other datasets, you can use them as well.


In [None]:
'''
Warnings: This cell will download a ArtBench dataset images.
'''
# Download the ArtBench-10 "imagefolder" archive and extract it. tar runs without
# -C, so the contents land in the current working directory.
!wget https://artbench.eecs.berkeley.edu/files/artbench-10-imagefolder-split.tar
!tar -xvf artbench-10-imagefolder-split.tar

### Option B. Crawl Images by Keywords
- The code below automatically downloads the images found for each keyword in `image_keywords` into a subdirectory of `images/`
  - You can write your own keywords in the `image_keywords` list.
  - You can set the number of images per keyword with `NUM_IMG`

In [None]:

'''
TODO: Write your own image_keywords list to search for different types of images
'''
image_keywords = ["French Art", "German Art", "British Art", "Russian Art"]
NUM_IMG = 50  # maximum number of images requested per keyword

assert all(isinstance(typ, str) for typ in image_keywords), "Every element of image_keywords has to be string"

'''
Warning: Running this cell will delete the content of the folder "images/", and download new images
'''
img_dir = Path('images')
shutil.rmtree(img_dir, ignore_errors=True)  # Delete the previous images automatically
for typ in image_keywords:
  # Saves into images/<keyword>/. Fewer than NUM_IMG files may actually arrive.
  duckduckgo_search(img_dir, typ, typ, max_results=NUM_IMG)


# Split each keyword's images: 1/5 go to images/test/<keyword>, the rest to images/train/<keyword>
random.seed(0)
for typ in image_keywords:
  typ_dir = img_dir / typ
  if not typ_dir.exists():
    continue
  train_dir = img_dir / 'train' / typ
  test_dir = img_dir / 'test' / typ
  train_dir.mkdir(parents=True, exist_ok=True)
  test_dir.mkdir(parents=True, exist_ok=True)
  img_files = get_image_files(typ_dir)
  # Sample from the files that actually exist. A fixed range(NUM_IMG) would raise
  # IndexError whenever a keyword downloaded fewer than NUM_IMG images.
  valid_imgs = set(random.sample(img_files, len(img_files) // 5))
  for fn in img_files:
    target_dir = test_dir if fn in valid_imgs else train_dir
    shutil.move(str(fn), str(target_dir / fn.name))
  assert len(get_image_files(typ_dir)) == 0, f"There are still remaining files in {typ_dir}"
  os.rmdir(typ_dir)

### 0.2. Make Dataset
- The pre-defined `ImageSet` class collects the paths of all images with `jpg` and `png` extensions (by default)
  - The list of image paths is saved as `image_fns`
  - The label of each image is automatically set to its parent directory's name


In [None]:
class ImageSet:
  """Image dataset whose label is the name of each file's parent directory.

  Args:
    path_dir: root directory searched recursively for images.
    file_types: extensions (without the dot) to include, matched with
      ``Path.rglob`` patterns such as ``'*.jpg'``.
    transform: callable applied to each RGB ``PIL.Image``. When ``None``, the
      standard ImageNet preprocessing is used: resize to 256, center-crop 224,
      convert to tensor, and normalize with the ImageNet mean/std.

  Attributes:
    image_fns: sorted list of image paths.
    classes: sorted list of distinct parent-directory names.
    cls2idx: mapping from class name to its index in ``classes``.
  """
  def __init__(self, path_dir, file_types=('jpg', 'png'), transform=None):
    self.path = Path(path_dir)
    self.image_fns = sorted(fn for ext in file_types for fn in self.path.rglob(f'*.{ext}'))
    self.classes = sorted({fn.parent.name for fn in self.image_fns})
    self.cls2idx = {name: idx for idx, name in enumerate(self.classes)}
    if transform is None:
      # Default preprocessing expected by ImageNet-pretrained torchvision models.
      transform = transforms.Compose([
                          transforms.Resize(256),
                          transforms.CenterCrop(224),
                          transforms.ToTensor(),
                          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      ])
    self.transform = transform

  def __len__(self):
    return len(self.image_fns)

  def __getitem__(self, idx):
    """Return ``(transformed_image, class_index)`` for item ``idx``."""
    img_path = self.image_fns[idx]
    # convert() returns a loaded copy, so the file handle can be closed right away.
    with Image.open(img_path) as raw:
      img = self.transform(raw.convert('RGB'))
    cls = img_path.parent.name
    return img, self.cls2idx[cls]


'''
TODO: Select a directory path for your dataset
'''

# Option A: use ArtBench dataset (test split only in default)
# NOTE(review): the download cell extracts the tar into the working directory, and
# section 3 uses 'artbench-10-imagefolder-split/...' without the 'data/' prefix.
# Confirm which path actually exists before uncommenting this line.
# dataset = ImageSet('data/artbench-10-imagefolder-split/test')

# Option B: use crawled dataset
# You can use both train and test set, by selecting path as 'images/'
# (labels come from each file's parent directory, so images/train/<kw> and
# images/test/<kw> map to the same class name <kw>).
dataset = ImageSet('images/')
len(dataset)  # shown as the cell output: number of images found

In [None]:
'''
Show Batch Examples

Because we use shuffle=True in the DataLoader below, the images will be shown in a random order.
'''

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
NUM_SHOW = 16  # fills the 4x4 grid below

dataloader = torch.utils.data.DataLoader(dataset, batch_size=NUM_SHOW, shuffle=True)
batch = next(iter(dataloader))

# Invert Normalize(mean, std), i.e. x * std + mean, as two Normalize steps:
# first divide by 1/std, then subtract -mean.
tensor2pil = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=[1 / s for s in IMAGENET_STD]),
    transforms.Normalize(mean=[-m for m in IMAGENET_MEAN], std=[1, 1, 1]),
    # Float rounding can push values just outside [0, 1]; clamp before the uint8 conversion.
    transforms.Lambda(lambda x: x.clamp(0, 1)),
    transforms.ToPILImage()
])

plt.figure(figsize=(10,10))

images, labels = batch
# The batch holds fewer than NUM_SHOW images when the dataset itself is smaller.
for i in range(len(images)):
  plt.subplot(4, 4, i+1)
  plt.imshow(tensor2pil(images[i]))
  plt.axis('off')
  plt.title(dataset.classes[labels[i]])

## 1. Load model and make embeddings of the images
- To make it run faster, `resnet18` is selected as the default model
  - You can select other models if you want.
- To get the final embedding of the model, which is the input to `resnet18.fc`, we register a forward hook

In [None]:
model = torchvision.models.resnet18(pretrained=True)
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of
# `weights=ResNet18_Weights.IMAGENET1K_V1`; kept here for compatibility with older installs.

# `hook` copies the inputs of `model.fc` (the final embedding fed to the
# classifier layer) into the global `embedding` on every forward pass.
embedding = None
def hook(module, inputs, output):
  """Forward hook: store the hooked layer's positional inputs in ``embedding``.

  PyTorch passes those inputs as a tuple, so readers use ``embedding[0]``.
  """
  global embedding
  embedding = inputs
model.fc.register_forward_hook(hook)

In [None]:
class EmbeddingVisualizer:
  """Compute, reduce and plot the image embeddings of a dataset under a model.

  Embeddings are read from the module-level global ``embedding``. The forward
  hook ``hook`` fills it with the inputs of the layer it is registered on
  (``model.fc`` in this notebook), so the model passed here must have that
  hook registered.

  Args:
    dataset: ``ImageSet``-like dataset that yields ``(image_tensor, label)``
      and exposes ``image_fns``. Its ``transform`` (if present) is reused to
      embed external images.
    model: torch module with the embedding hook registered.
    emb_method: reduction used for ``red_embs``: 'umap', 'tsne' or 'pca'.
    device: ``torch.device`` or device string. When None, CUDA is used if
      available, otherwise CPU.
    batch_size: batch size used while computing the dataset embeddings.
  """

  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
  _IMAGENET_STD = [0.229, 0.224, 0.225]

  def __init__(self, dataset, model, emb_method='umap', device=None, batch_size=16):
    self.dataset = dataset
    self.model = model
    self.device = self._resolve_device(device)
    self.model.to(self.device)
    self.model.eval()
    self.batch_size = batch_size

    # Display only: no normalization, so thumbnails keep their real colours.
    self.center_crop = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
    # Embedding of external images must use the same preprocessing as the dataset;
    # otherwise query embeddings are not comparable to self.embeddings.
    embed_transform = getattr(dataset, 'transform', None)
    if embed_transform is None:
      embed_transform = transforms.Compose([
          transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
      ])
    self.embed_transform = embed_transform

    print("Calculating Embeddings...")
    self.embeddings = self.get_embedding()
    self.red_embs, self.reducer = self.get_reduced_embedding(emb_method)

  def __len__(self):
    return len(self.dataset)

  @staticmethod
  def _resolve_device(device):
    """Return ``device`` as a ``torch.device``; default to CUDA when available."""
    if device is None:
      return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(device)

  @staticmethod
  def _load_rgb(path):
    """Open an image file, return it as an RGB ``PIL.Image`` and close the file."""
    with Image.open(path) as img:
      return img.convert('RGB')

  @staticmethod
  def _sample_indices(n_total, num, rand_seed):
    """Pick the indices to plot.

    Returns ``num`` distinct indices drawn exactly as
    ``np.random.seed(rand_seed); np.random.choice(n_total, num, replace=False)``
    would, but without reseeding the global RNG. Returns all indices when
    ``num >= n_total``.
    """
    if num < n_total:
      return np.random.RandomState(rand_seed).choice(n_total, num, replace=False)
    return np.arange(n_total)

  def _forward_embedding(self, img):
    """Run ``self.model`` on ``img`` and return the tensor captured by the hook.

    Raises:
      RuntimeError: if the forward pass did not trigger the embedding hook.
    """
    global embedding
    embedding = None
    self.model(img.to(self.device))
    if embedding is None:
      raise RuntimeError("The embedding hook did not fire; register `hook` on the model's layer first.")
    captured = embedding[0]
    embedding = None
    return captured

  def get_embedding(self):
    """Return the embeddings of all dataset items, in dataset order, as a CPU tensor."""
    dataloader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=False)
    embeddings = []
    with torch.inference_mode():
      for img, _ in tqdm(dataloader):
        embeddings.append(self._forward_embedding(img))
    return torch.cat(embeddings, dim=0).cpu()

  def get_reduced_embedding(self, method='umap', n_components=2):
    """Fit a dimensionality reducer on ``self.embeddings``.

    Args:
      method: 'umap', 'tsne' or 'pca'.
      n_components: output dimensionality.

    Returns:
      ``(reduced, reducer)``: the output of the reducer's ``fit_transform``
      and the fitted reducer itself.

    Raises:
      NotImplementedError: for an unknown ``method``.
    """
    if method == 'umap':
      reducer = umap.UMAP(n_components=n_components)
    elif method == 'tsne':
      reducer = TSNE(n_components=n_components)
    elif method == 'pca':
      reducer = PCA(n_components=n_components)
    else:
      raise NotImplementedError(f"{method} is not implemented")
    # Fit once: fitting is the expensive (and, for UMAP/t-SNE, stochastic) step.
    return reducer.fit_transform(self.embeddings.cpu().numpy()), reducer

  def get_image(self, path, zoom=0.2):
    """Return a center-cropped thumbnail of ``path`` as a matplotlib ``OffsetImage``."""
    img = self.center_crop(self._load_rgb(path))
    return OffsetImage(img.permute(1, 2, 0).numpy(), zoom=zoom)

  def plot_images_on_embedding_path(self, embeddings, input_a, input_b, k=10, add_paths=None):
    """Plot the dataset images nearest to k+1 evenly spaced points from A to B.

    Args:
      embeddings: (N, D) tensor, with row i belonging to ``dataset.image_fns[i]``.
      input_a, input_b: either row indices into ``embeddings`` or embedding tensors.
      k: number of interpolation intervals (k+1 points, including both ends).
      add_paths: optional ``[path_a, path_b]``; the first is shown before the
        path and the rest after it.

    Raises:
      ValueError: for mixed or unsupported input types, or ``k < 1``.
    """
    if isinstance(input_a, (int, np.integer)) and isinstance(input_b, (int, np.integer)):
      emb_a, emb_b = embeddings[input_a], embeddings[input_b]
    elif isinstance(input_a, torch.Tensor) and isinstance(input_b, torch.Tensor):
      emb_a, emb_b = input_a, input_b
    else:
      raise ValueError("input_a and input_b must be either int or torch.Tensor")
    if k < 1:
      raise ValueError("k must be a positive integer")

    mid_embs = torch.stack([emb_a * (1 - i / k) + emb_b * (i / k) for i in range(k + 1)], dim=0)
    nearest_idx = torch.cdist(mid_embs, embeddings).argmin(dim=1)
    nearest_paths = [self.dataset.image_fns[idx] for idx in nearest_idx.tolist()]
    if isinstance(add_paths, (list, tuple)):
      add_paths = list(add_paths)  # allows tuples: tuple + list would raise TypeError
      nearest_paths = add_paths[0:1] + nearest_paths + add_paths[1:]
    images = [self._load_rgb(path) for path in nearest_paths]

    fig, axes = plt.subplots(1, len(images), figsize=(30, 5), squeeze=False)
    for ax, image in zip(axes[0], images):
      ax.imshow(image)
      ax.axis('off')

  def plot_embeddings_with_image(self, embeddings, num, figsize=(20, 20), zoom=0.2, rand_seed=0):
    """Scatter up to ``num`` randomly chosen images at their 2-D coordinates."""
    idxs = self._sample_indices(len(embeddings), num, rand_seed)
    paths = [str(self.dataset.image_fns[i]) for i in idxs]

    x = embeddings[idxs, 0]
    y = embeddings[idxs, 1]

    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(x, y)
    for x0, y0, path in zip(x, y, paths):
      ax.add_artist(AnnotationBbox(self.get_image(path, zoom), (x0, y0), frameon=False))

  def plot_rasterfairy_embeddings(self, embeddings, num, figsize=20, zoom=0.2, rand_seed=0):
    """Like ``plot_embeddings_with_image``, but snaps the points onto a grid with RasterFairy.

    ``figsize`` is a scalar, scaled by the grid shape returned by RasterFairy.
    """
    idxs = self._sample_indices(len(embeddings), num, rand_seed)
    paths = [str(self.dataset.image_fns[i]) for i in idxs]

    grid_xy, shape = rasterfairy.transformPointCloud2D(embeddings[idxs])
    new_figsize = (figsize / 20 * shape[0], figsize / 20 * shape[1])
    fig, ax = plt.subplots(figsize=new_figsize)
    x, y = grid_xy[:, 0], grid_xy[:, 1]
    ax.scatter(x, y)
    for x0, y0, path in zip(x, y, paths):
      ax.add_artist(AnnotationBbox(self.get_image(path, zoom), (x0, y0), frameon=False))

  def get_embedding_from_image_path(self, path):
    """Embed an external image file with ``self.model``.

    Returns a CPU tensor with a leading batch dimension of 1.
    """
    img = self.embed_transform(self._load_rgb(path)).unsqueeze(0)
    with torch.inference_mode():
      return self._forward_embedding(img).cpu()

  def plot_most_similar_images(self, path, num=10):
    """Show the ``num`` dataset images whose embeddings are closest to the image at ``path``."""
    emb = self.get_embedding_from_image_path(path)
    dists = torch.cdist(emb, self.embeddings)[0]
    nearest_idx = dists.argsort()[:num]
    nearest_paths = [self.dataset.image_fns[idx] for idx in nearest_idx.tolist()]
    images = [self._load_rgb(p) for p in nearest_paths]
    fig, axes = plt.subplots(1, len(images), figsize=(30, 5), squeeze=False)
    for ax, image in zip(axes[0], images):
      ax.imshow(image)
      ax.axis('off')

  def plot_path_between_two_images(self, path_a, path_b, pca_n_comp=2, k=10):
    """Plot image A, the dataset images along the PCA-space path A->B, then image B.

    Args:
      path_a, path_b: image files to embed as the path endpoints.
      pca_n_comp: number of PCA components of the space the path is drawn in.
      k: number of interpolation intervals.
    """
    emb_a = self.get_embedding_from_image_path(path_a)
    emb_b = self.get_embedding_from_image_path(path_b)
    pca_embs, reducer = self.get_reduced_embedding('pca', n_components=pca_n_comp)
    pca_embs = torch.tensor(pca_embs, dtype=torch.float32)
    emb_a = torch.tensor(reducer.transform(emb_a.numpy()), dtype=torch.float32)
    emb_b = torch.tensor(reducer.transform(emb_b.numpy()), dtype=torch.float32)
    self.plot_images_on_embedding_path(pca_embs, emb_a[0], emb_b[0], k=k, add_paths=[path_a, path_b])


# Build the visualizer: computes an embedding for every image in `dataset` with the
# ImageNet-pretrained `model` and reduces them to 2-D with UMAP (stored as `red_embs`).
visualizer = EmbeddingVisualizer(dataset, model, 'umap')
visualizer.embeddings  # shown as the cell output

### 1.1 Plot Scatter of Images
- By default, it visualizes the embeddings reduced to 2-D with UMAP
  - You can try `tsne` or `pca` with code like the following:
    - `visualizer.plot_embeddings_with_image(visualizer.get_reduced_embedding(method='tsne')[0], num=200, zoom=0.4, rand_seed=1)`

In [None]:
# Scatter up to 200 randomly sampled images (all of them if the dataset is smaller)
# at their 2-D UMAP coordinates.
visualizer.plot_embeddings_with_image(visualizer.red_embs, num=200, zoom=0.4, rand_seed=10)

# Other options: 'pca', 'tsne', 'umap'
# visualizer.plot_embeddings_with_image(visualizer.get_reduced_embedding(method='tsne')[0], num=200, zoom=0.4, rand_seed=10)

In [None]:
# Arrange the sampled images on a regular grid with RasterFairy's transformPointCloud2D.
# figsize for this function takes only scalar value.
# figsize 20 and zoom 0.2 is default settings
visualizer.plot_rasterfairy_embeddings(visualizer.red_embs, num=200, figsize=20, zoom=0.2, rand_seed=10)

## 2. Plot paths between images
- Depending on the density of your dataset, you can choose a different `n_components` for PCA
  - If `n_components` is too high relative to the number of images in the dataset, there may be no image embeddings at all between the two given image embeddings

In [None]:
# Reduce all embeddings to 2-D with PCA, then show the dataset images nearest to the
# k+1 evenly spaced points on the straight line from image 0 to image 1.
pca_embedding = visualizer.get_reduced_embedding('pca', n_components=2)[0]
visualizer.plot_images_on_embedding_path(torch.tensor(pca_embedding), input_a=0, input_b=1, k=10)

### Use your own image
- You can load your own image to compute its embedding, then plot the most similar images or the images along the path between two images
  - If you are using Colab, you can upload it via the file explorer in the left-side panel

In [None]:
# TODO: you can use your own image path
# (the empty placeholder fails when the image is opened, so set a real file path first)
your_image_path = ''
visualizer.plot_most_similar_images(your_image_path, num=10)

In [None]:
# Replace both placeholders with real image paths. The plot shows image A, then the
# dataset images nearest to the PCA-space path from A to B, then image B.
your_image_path_a = 'write down your image path here'
your_image_path_b = 'write down your image path here'

visualizer.plot_path_between_two_images(your_image_path_a, your_image_path_b)

## 3 (Optional). Fine tune Model and Visualize Again
- Different models see different things in the same image.
  - If you fine-tune the model to classify your own dataset (instead of ImageNet), the latent space will change
- It is not mandatory to split train/validation/test strictly to observe the latent space
  - But if you want to check how well fine-tuning is going, it is better to use at least a train/test split.

In [None]:
'''
TODO: Fill in your training and test set paths to the trainset and testset
'''

# trainset_path = 'artbench-10-imagefolder-split/train'
# testset_path = 'artbench-10-imagefolder-split/test'
trainset_path = 'images/train/'
testset_path = 'images/test/'


trainset = ImageSet(trainset_path)
testset = ImageSet(testset_path)
# Each ImageSet builds its own class->index mapping. Share the training mapping so
# that a class name gets the same label index in both sets (they would diverge
# if some class had no test images).
missing = sorted(set(testset.classes) - set(trainset.classes))
if missing:
  raise ValueError(f"Test classes not present in the training set: {missing}")
testset.classes = trainset.classes
testset.cls2idx = trainset.cls2idx

train_loader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)

new_model = torchvision.models.resnet18(pretrained=True)
# The new classifier head outputs one logit per *training* class.
new_model.fc = torch.nn.Linear(new_model.fc.in_features, len(trainset.classes))

# Re-register the embedding hook on the replaced fc layer.
new_model.fc.register_forward_hook(hook)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

new_model = new_model.to(device)
# Conventional order: build the optimizer from the parameters of the already-moved model.
optimizer = torch.optim.Adam(new_model.parameters(), lr=0.0003)

criterion = torch.nn.CrossEntropyLoss()

In [None]:
# fine tune your model

def get_accuracy(model, data_loader, device):
  """Return the fraction of samples in ``data_loader`` that ``model`` classifies correctly.

  The model is moved to ``device`` and evaluated in eval mode (so layers such as
  BatchNorm use their running statistics). Its previous train/eval mode is
  restored afterwards.

  Args:
    model: classifier returning logits of shape (batch, num_classes).
    data_loader: yields ``(images, labels)`` batches; ``data_loader.dataset``
      must support ``len``.
    device: ``torch.device`` or device string.

  Returns:
    Accuracy in [0, 1]; 0.0 when the dataset is empty.
  """
  model.to(device)
  was_training = model.training
  model.eval()
  n_correct = 0
  try:
    with torch.inference_mode():
      for images, labels in data_loader:
        out = model(images.to(device))
        n_correct += (out.argmax(dim=1) == labels.to(device)).sum().item()
  finally:
    model.train(was_training)
  n_total = len(data_loader.dataset)
  return n_correct / n_total if n_total else 0.0


n_epochs = 10
loss_record = []  # training loss of every mini-batch (one entry per optimizer step)
accuracies = []  # test accuracy measured after each epoch

new_model.to(device)
for i in range(n_epochs):
  new_model.train()  # training-mode behaviour (e.g. BatchNorm statistics updates)
  for batch in train_loader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    out = new_model(images)
    loss = criterion(out, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_record.append(loss.item())
  new_model.eval()  # switch to eval mode before measuring test accuracy
  accuracy = get_accuracy(new_model, test_loader, device)
  accuracies.append(accuracy)
  print(f'Epoch {i+1} accuracy: {accuracy}')


# Top: per-step training loss. Bottom: per-epoch test accuracy.
plt.subplot(2,1,1)
plt.plot(loss_record)
plt.title('Training Loss')
plt.subplot(2,1,2)
plt.plot(accuracies)
plt.title('Test Accuracy')

In [None]:
# Intended to embed `dataset` with the fine-tuned `new_model` (its fc hook is registered above).
# NOTE(review): this reflects fine-tuning only if EmbeddingVisualizer forwards through
# `self.model` rather than the global `model`. Verify in the class definition.
new_visualizer = EmbeddingVisualizer(dataset, new_model, 'umap')

In [None]:
# Scatter up to 200 sampled images at their UMAP coordinates under the new visualizer.
new_visualizer.plot_embeddings_with_image(new_visualizer.red_embs, 200, zoom=0.4, rand_seed=10)

In [None]:
# Grid (RasterFairy) layout of up to 500 sampled images under the new visualizer.
new_visualizer.plot_rasterfairy_embeddings(new_visualizer.red_embs, 500, figsize=20, zoom=0.2, rand_seed=10)