<a href="https://colab.research.google.com/github/hits-sdo/hits-sdo-similaritysearch/blob/ss_dataloader/search_simsiam/simsiam_example_notebook_HITS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository. The directory test skips the clone when a
# checkout is already present, so re-running this cell does not fail with
# "destination path already exists".
![ -d hits-sdo-similaritysearch ] || git clone https://github.com/hits-sdo/hits-sdo-similaritysearch

Cloning into 'hits-sdo-similaritysearch'...
remote: Enumerating objects: 309, done.[K
remote: Counting objects: 100% (309/309), done.[K
remote: Compressing objects: 100% (265/265), done.[K
remote: Total 309 (delta 61), reused 275 (delta 36), pack-reused 0[K
Receiving objects: 100% (309/309), 5.05 MiB | 16.95 MiB/s, done.
Resolving deltas: 100% (61/61), done.


In [None]:
# Install lightly with the %pip magic so the package lands in the environment
# of the running kernel rather than whichever pip is first on the shell PATH.
%pip install lightly

In [None]:
import math

import numpy as np
import torch
import torch.nn as nn
import torchvision

from lightly.data import ImageCollateFunction, LightlyDataset, collate
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules.heads import SimSiamPredictionHead, SimSiamProjectionHead

In [None]:
# Download the tile archive from Google Drive, unpack it, then delete the archive.
# The steps are chained with && so the unpack only runs after a successful download.
# The whole sequence is skipped when the extracted folder already exists, so
# re-running the cell does not fetch the archive again.
# NOTE(review): assumes the archive unpacks into ./aia_171_color_1perMonth
# (the folder that path_to_data points at) - confirm.
![ -d aia_171_color_1perMonth ] || (gdown 15C5spf1la7L09kvWXll2qt67Ec0rwLsY && tar -zxf aia_171_color_1perMonth.tar.gz && rm aia_171_color_1perMonth.tar.gz)

Downloading...
From: https://drive.google.com/uc?id=15C5spf1la7L09kvWXll2qt67Ec0rwLsY
To: /content/aia_171_color_1perMonth.tar.gz
100% 146M/146M [00:01<00:00, 93.0MB/s]


In [None]:
# Root folder of the image tiles read by the datasets below.
# Alternative location when the data lives on a mounted Google Drive:
# path_to_data = '/content/gdrive/MyDrive/HITS/aia_171_color_1perMonth/'
path_to_data = '/content/aia_171_color_1perMonth'

In [None]:
# --- data loading ---------------------------------------------------------
# number of worker processes the dataloaders use to load batches
num_workers = 8
# number of images per batch handed to the model
batch_size = 32
# side length in pixels (x and y) of the square model input
input_size = 120

# --- training -------------------------------------------------------------
# seed for the random number generators (reproducibility)
seed = 1
# number of iterations of the outer training loop
epochs = 50

# --- model dimensions -----------------------------------------------------
# dimension of the embeddings produced by the backbone
num_ftrs = 512
# hidden dimension passed to the projection head
proj_hidden_dim = 512
# output dimension passed to the projection and prediction heads
out_dim = 512
# the prediction head uses a bottleneck architecture, hence the smaller hidden size
pred_hidden_dim = 128

In [None]:
# Seed the torch and numpy random number generators for reproducibility
# when creating the model.
# The value comes from `seed` in the configuration cell, so changing it there
# actually takes effect (both generators were previously hard-coded to 0).
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Augmentation settings for self-supervised learning; the collate function
# built from them turns every input image into two augmented views.
augmentation_kwargs = dict(
    input_size=input_size,
    # flips and rotations, each applied with probability 0.5, so the model
    # is required to be invariant to them
    hf_prob=0.5,
    vf_prob=0.5,
    rr_prob=0.5,
    # random cropping
    # NOTE(review): the original remark here ("satellite images are all taken
    # from the same height so we use only slight random cropping") was written
    # for satellite imagery - confirm that 0.5 is the intended crop range for
    # these tiles.
    min_scale=0.5,
    # weak color jitter for invariance w.r.t. small color changes
    cj_prob=0.2,
    cj_bright=0.1,
    cj_contrast=0.1,
    cj_hue=0.1,
    cj_sat=0.1,
)
collate_fn = ImageCollateFunction(**augmentation_kwargs)

# Smoke test for the collate function
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Push one all-black 128x128 RGB image through `collate_fn` as a batch of one
# and check the shape of the two augmented views that come back.
img = torchvision.transforms.ToPILImage()(np.zeros((128, 128, 3), dtype=np.uint8))

# One (image, label, file name) sample; named `dummy_batch` so the builtin
# input() is not shadowed for the rest of the notebook.
dummy_batch = [(img, 0, 'my-image.png')]

output = collate_fn(dummy_batch)

(img_t0, img_t1), label, filename = output

# Each view should be a batch of one 3-channel image at the training resolution.
expected_shape = (1, 3, input_size, input_size)
assert img_t0.shape == expected_shape, f"unexpected view shape {tuple(img_t0.shape)}"
assert img_t1.shape == expected_shape, f"unexpected view shape {tuple(img_t1.shape)}"



# Training dataset. No transform is passed here because the augmentations
# are applied by the collate function when batches are assembled.
dataset_train_simsiam = LightlyDataset(input_dir=path_to_data)

# number of tiles found: 105056 = 3283 batches x 32 tiles per batch
n_train_tiles = len(dataset_train_simsiam)
print(n_train_tiles)

# a single item is a tuple of (image, folder number, tile name)
sample_item = dataset_train_simsiam[800]
print(sample_item)


# Dataloader that feeds the training loop with augmented batches.
dataloader_train_simsiam = torch.utils.data.DataLoader(
    dataset=dataset_train_simsiam,
    batch_size=batch_size,
    num_workers=num_workers,
    # reshuffle the data at the start of every epoch
    shuffle=True,
    # builds each batch: two augmented views per image plus labels and file names
    collate_fn=collate_fn,
    # drop the final batch when it holds fewer than batch_size samples
    drop_last=True,
)

# Preprocessing used when embedding the dataset after training:
#   1. resize to the input size used during training,
#   2. convert to a tensor,
#   3. normalize the color channels with the imagenet statistics shipped
#      with lightly.
imagenet_stats = collate.imagenet_normalize
resize_to_input = torchvision.transforms.Resize((input_size, input_size))
imagenet_norm = torchvision.transforms.Normalize(
    mean=imagenet_stats["mean"],
    std=imagenet_stats["std"],
)
test_transforms = torchvision.transforms.Compose(
    [resize_to_input, torchvision.transforms.ToTensor(), imagenet_norm]
)

# Dataset for embedding: the same image folder as for training, but every
# image goes through the test-time preprocessing instead of random
# augmentations.
dataset_test = LightlyDataset(transform=test_transforms, input_dir=path_to_data)

# Dataloader for embedding: fixed order, and the last partial batch is kept.
dataloader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
)

105056
(<PIL.Image.Image image mode=RGB size=128x128 at 0x7FE85715F970>, 1, '20100703_000036_aia.lev1_euv_12s_4k/tiles/20100703_000036_aia.lev1_euv_12s_4k_tile_1024_2944.jpg')


In [None]:
class SimSiam(nn.Module):
    """Backbone followed by a SimSiam projection head and prediction head.

    Args:
        backbone: module that maps an image batch to features.
        num_ftrs: first argument given to the projection head.
        proj_hidden_dim: second argument given to the projection head.
        pred_hidden_dim: second argument given to the prediction head.
        out_dim: last argument of both heads and first argument of the
            prediction head.
    """

    def __init__(self, backbone, num_ftrs, proj_hidden_dim, pred_hidden_dim, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(
            num_ftrs, proj_hidden_dim, out_dim
        )
        self.prediction_head = SimSiamPredictionHead(
            out_dim, pred_hidden_dim, out_dim
        )

    def forward(self, x):
        """Return ``(z, p)`` for the batch ``x``.

        ``z`` is the projection, detached from the autograd graph (the stop
        gradient); ``p`` is the prediction computed from the projection.
        """
        # backbone output flattened to one feature vector per sample
        features = self.backbone(x)
        features = features.flatten(start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        # stop gradient: only the prediction carries gradients back
        return projection.detach(), prediction


# ResNet-18 backbone. resnet18() is called without a weights argument, so the
# network is NOT pretrained here: it is trained from scratch.
resnet = torchvision.models.resnet18()
# keep every child module except the last one and chain them into the backbone
backbone_layers = list(resnet.children())[:-1]
backbone = nn.Sequential(*backbone_layers)
model = SimSiam(
    backbone=backbone,
    num_ftrs=num_ftrs,
    proj_hidden_dim=proj_hidden_dim,
    pred_hidden_dim=pred_hidden_dim,
    out_dim=out_dim,
)

In [None]:
# Loss: negative cosine similarity, applied symmetrically in the training loop.
criterion = NegativeCosineSimilarity()

# Learning rate scaled with the batch size (0.05 at a batch size of 256).
lr = 0.05 * batch_size / 256

# SGD with momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    momentum=0.9,
    weight_decay=5e-4,
    lr=lr,
)

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

# Cap on the number of batches processed per epoch, kept at 10 for quick
# experiments. Use a positive integer, or None to iterate over the whole
# dataloader in every epoch.
max_batches_per_epoch = 10
# weight of the previous value in the moving averages below
w = 0.9

avg_loss = 0.0
avg_output_std = 0.0
for e in range(epochs):
    for batch_count, ((x0, x1), _, _) in enumerate(dataloader_train_simsiam):
        # move both augmented views to the training device
        x0 = x0.to(device)
        x1 = x1.to(device)

        # run the model on both transforms of the images
        # we get projections (z0 and z1) and
        # predictions (p0 and p1) as output
        z0, p0 = model(x0)
        z1, p1 = model(x1)

        # apply the symmetric negative cosine similarity
        # and run backpropagation
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # calculate the per-dimension standard deviation of the l2-normalized
        # outputs; used below to check whether the embeddings are collapsing
        output = p0.detach()
        output = torch.nn.functional.normalize(output, dim=1)
        output_std = torch.std(output, 0).mean()

        # use moving averages to track the loss and standard deviation.
        # This runs before the batch cap is checked so that every batch the
        # model was trained on is also reflected in the statistics.
        avg_loss = w * avg_loss + (1 - w) * loss.item()
        avg_output_std = w * avg_output_std + (1 - w) * output_std.item()

        # log a plain float instead of the tensor repr
        print(f"  batch {batch_count:4d} | loss = {loss.item():.4f}")

        if max_batches_per_epoch is not None and batch_count + 1 >= max_batches_per_epoch:
            break

    # the level of collapse is large if the standard deviation of the l2
    # normalized output is much smaller than 1 / sqrt(dim)
    collapse_level = max(0.0, 1 - math.sqrt(out_dim) * avg_output_std)
    # print intermediate results
    print(
        f"[Epoch {e:3d}] "
        f"Loss = {avg_loss:.2f} | "
        f"Collapse Level: {collapse_level:.2f} / 1.00"
    )



0
tensor(-0.0182, grad_fn=<MulBackward0>)
1
tensor(-0.0187, grad_fn=<MulBackward0>)
2
tensor(-0.0224, grad_fn=<MulBackward0>)
3
tensor(-0.0387, grad_fn=<MulBackward0>)
4
tensor(-0.0412, grad_fn=<MulBackward0>)
5
tensor(-0.0397, grad_fn=<MulBackward0>)
6
tensor(-0.0404, grad_fn=<MulBackward0>)
7
tensor(-0.0519, grad_fn=<MulBackward0>)
8
tensor(-0.0506, grad_fn=<MulBackward0>)
9
tensor(-0.0369, grad_fn=<MulBackward0>)
[Epoch   0] Loss = -0.02 | Collapse Level: 0.50 / 1.00
0
tensor(-0.0756, grad_fn=<MulBackward0>)
1
tensor(-0.0771, grad_fn=<MulBackward0>)
2
tensor(-0.0692, grad_fn=<MulBackward0>)
3
tensor(-0.0850, grad_fn=<MulBackward0>)
4
tensor(-0.0870, grad_fn=<MulBackward0>)
5
tensor(-0.0913, grad_fn=<MulBackward0>)
6
tensor(-0.1101, grad_fn=<MulBackward0>)
7
tensor(-0.1122, grad_fn=<MulBackward0>)
8
tensor(-0.1267, grad_fn=<MulBackward0>)
9
tensor(-0.1142, grad_fn=<MulBackward0>)
[Epoch   1] Loss = -0.07 | Collapse Level: 0.31 / 1.00
0
tensor(-0.1138, grad_fn=<MulBackward0>)
1
tensor

KeyboardInterrupt: ignored

In [None]:
# Change into the search_simsiam folder of the cloned repository.
# NOTE: the path is relative, so this cell only succeeds while the working
# directory is the folder the repository was cloned into; running it a
# second time from inside the repository fails.
%cd hits-sdo-similaritysearch/search_simsiam/

/content/hits-sdo-similaritysearch/search_simsiam


In [None]:
# Switch the clone to the ss_dataloader branch.
!git checkout ss_dataloader

Branch 'ss_dataloader' set up to track remote branch 'ss_dataloader' from 'origin'.
Switched to a new branch 'ss_dataloader'


In [None]:
# Confirm that the branch is up to date by listing its most recent commits.
# Limited to five entries so the cell output stays short; the newest commit
# is the first line.
!git log --oneline -n 5

[33me47b776[m[33m ([m[1;36mHEAD -> [m[1;32mss_dataloader[m[33m, [m[1;31morigin/ss_dataloader[m[33m)[m team-sunbird used stitch_adj_images to fill voids in two augmented views
[33macdcd3c[m[33m ([m[1;31morigin/main[m[33m, [m[1;31morigin/HEAD[m[33m, [m[1;32mmain[m[33m)[m Merge pull request #1 from hits-sdo/initialization
[33mb6b76a7[m[33m ([m[1;31morigin/initialization[m[33m)[m Make main the default branch in the notebook setup
[33m3e41ea2[m Added download to setup notebook
[33m7135c8a[m replaced 'search-utils' with 'search_utils'
[33m8c2014c[m replaced '-' with '_' in module names
[33m97a1002[m added search_utils to sys.path
[33mdd9c215[m added search-utils to sys.path
[33m0702056[m renamed folder module utils to search-utils
[33m68990e6[m renamed folder module utils to search-utils
[33m23f7e1e[m renamed folder module utils to search-utils
[33mcae0aef[m renamed folder module utils to search-utils
[33m341a9e8[m two augmentations f