
add example of self-supervised SimCLR training - V2 #50

Merged · 9 commits · Dec 21, 2021
45 changes: 45 additions & 0 deletions examples/simclr_cifar10_data_parallel/NT_Xentloss.py
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.registry import LOSSES
from torch.nn.modules.linear import Linear

@LOSSES.register_module
class NT_Xentloss(nn.Module):
def __init__(self, temperature=0.5):
super().__init__()
self.temperature = temperature

    def forward(self, z1, z2, label):
        # `label` is ignored; it is kept in the signature so the loss can be used with
        # trainers that pass (outputs, labels) to the criterion.
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        N, Z = z1.shape
        device = z1.device
        representations = torch.cat([z1, z2], dim=0)
        # Pairwise cosine similarity between all 2N projected views.
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)
        # Positive pairs lie on the +N and -N off-diagonals (the two views of the same image).
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)
        # Mask out self-similarities and positive pairs, leaving 2N-2 negatives per sample.
        diag = torch.eye(2*N, dtype=torch.bool, device=device)
        diag[N:, :N] = diag[:N, N:] = diag[:N, :N]

        negatives = similarity_matrix[~diag].view(2*N, -1)

        logits = torch.cat([positives, negatives], dim=1)
        logits /= self.temperature

        # The positive similarity sits in column 0 of every row, so the target class is 0.
        labels = torch.zeros(2*N, device=device, dtype=torch.int64)

        loss = F.cross_entropy(logits, labels, reduction='sum')
        return loss / (2 * N)


if __name__=='__main__':
criterion = NT_Xentloss()
net = Linear(256,512)
output = [net(torch.randn(512,256)), net(torch.randn(512,256))]
label = [torch.randn(512)]
loss = criterion(*output, *label)
print(loss)


30 changes: 30 additions & 0 deletions examples/simclr_cifar10_data_parallel/README.md
@@ -0,0 +1,30 @@
# Overview

This example trains [SimCLR](https://arxiv.org/abs/2002.05709) on CIFAR10 with a [PreAct-ResNet18](https://arxiv.org/abs/1603.05027) backbone.
SimCLR is a self-supervised representation learning algorithm that learns generic image representations from an unlabeled dataset. The representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between views of different images, an approach known as contrastive learning. Updating the network parameters with this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other. A more detailed description of SimCLR is available [here](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html).
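
Concretely, the contrastive objective implemented in `NT_Xentloss.py` corresponds to the NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper. For a positive pair of views $(i, j)$ taken from a batch of $N$ images (so $2N$ augmented views in total), the per-pair loss is

$$\ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j)/\tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]}\,\exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)}$$

where $\mathrm{sim}(\cdot,\cdot)$ is the cosine similarity between projections, $\tau$ is the temperature (0.5 by default in this example), and the total loss averages $\ell_{i,j}$ over all $2N$ positive pairs.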

The training process consists of two phases: (1) self-supervised representation learning, in which the model that acts as a feature extractor is trained exactly as described above; and (2) linear evaluation, in which, to evaluate how well the representations have been learned, a linear classifier is added on top of the feature extractor trained in phase 1. The linear classifier is trained on a labeled dataset in the conventional supervised manner, while the parameters of the feature extractor are kept frozen. This procedure is called linear evaluation; a minimal sketch of the pattern follows.
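
For reference, here is a minimal sketch of the linear-evaluation pattern (a frozen encoder followed by a trainable linear head). The class and variable names are purely illustrative; the implementation actually used in this example is `models/linear_eval.py`.

```python
import torch
import torch.nn as nn

class LinearEvalSketch(nn.Module):
    """Illustrative only: frozen feature extractor + trainable linear classifier."""
    def __init__(self, encoder: nn.Module, feature_dim: int, num_classes: int = 10):
        super().__init__()
        self.encoder = encoder
        self.encoder.requires_grad_(False)                     # phase 2: backbone parameters stay fixed
        self.classifier = nn.Linear(feature_dim, num_classes)  # only this layer is optimized

    def forward(self, x):
        with torch.no_grad():                                  # no gradients through the frozen encoder
            features = self.encoder(x)
        return self.classifier(features)
```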

# How to run
The training commands are collected in `train.sh`; start self-supervised training with:
```shell
bash train.sh
```
Before running, you can specify the experiment name (folders with the same name will be created in `ckpt` to save checkpoints and in `tb_logs` to save the TensorBoard files) and other training hyperparameters in `config.py`. By default, the CIFAR10 dataset is downloaded automatically and saved in `./dataset`. Note that `LOG_NAME` in `le_config.py` should match the one in `config.py`.

Besides linear evaluation, you can also visualize the distribution of the learned representations. A script is provided that first extracts representations and then visualizes them with [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). t-SNE is a useful tool for visualizing high-dimensional data: it converts similarities between data points to joint probabilities and tries to minimize the Kullback-Leibler divergence between the joint probabilities of the low-dimensional embedding and the high-dimensional data. Run the script with (remember to modify `log_name` and `epoch` to specify which experiment folder and which training epoch to load the model from):
```shell
python visualization.py
```
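
If you want to experiment with t-SNE outside of `visualization.py`, a minimal sketch is shown below. It assumes hypothetical files `features.npy` and `labels.npy` holding the extracted representations and their class ids; replace them with however you export your own features.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# features: (n_samples, feature_dim) representations extracted by the trained backbone
# labels:   (n_samples,) CIFAR10 class ids, used only to color the scatter plot
features = np.load('features.npy')
labels = np.load('labels.npy')

embedding = TSNE(n_components=2, init='pca', perplexity=30).fit_transform(features)
plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, s=2, cmap='tab10')
plt.title('t-SNE of learned representations')
plt.savefig('tsne.png')
```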

# Results
The loss curve of SimCLR self-supervised training is as follows:
![SimCLR Loss Curve](./results/ssl_loss.png)
The loss curve of linear evaluation is as follows:
![Linear Evaluation Loss Curve](./results/linear_eval_loss.png)
The accuracy curve of linear evaluation is as follows:
![Linear Evaluation Accuracy](./results/linear_eval_acc.png)
The t-SNE visualization of the CIFAR10 training set is as follows:
![train tSNE](./results/train_tsne.png)
The t-SNE visualization of the CIFAR10 test set is as follows:
![test tSNE](./results/test_tsne.png)
32 changes: 32 additions & 0 deletions examples/simclr_cifar10_data_parallel/augmentation.py
@@ -0,0 +1,32 @@
from torchvision.transforms import transforms

class SimCLRTransform():
def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Blur kernel is roughly 10% of the 32-pixel image size, forced to be odd (here: 3).
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=32//20*2+1, sigma=(0.1, 2.0))], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])

    def __call__(self, x):
        # Return two independently augmented views of the same image.
        x1 = self.transform(x)
        x2 = self.transform(x)
        return x1, x2


class LeTransform():
def __init__(self):
self.transform = transforms.Compose([
transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

def __call__(self, x):
x = self.transform(x)
return x
23 changes: 23 additions & 0 deletions examples/simclr_cifar10_data_parallel/config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE


LOG_NAME = 'cifar-simclr'

BATCH_SIZE = 512
NUM_EPOCHS = 801
# Linear learning-rate scaling: a base LR of 0.03 per 256 samples in the batch.
LEARNING_RATE = 0.03*BATCH_SIZE/256
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9


fp16 = dict(
mode=AMP_TYPE.TORCH,
)

dataset = dict(
root='./dataset',
)

gradient_accumulation=2
gradient_clipping=1.0

23 changes: 23 additions & 0 deletions examples/simclr_cifar10_data_parallel/le_config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE


LOG_NAME = 'cifar-simclr'
# Training epoch of the self-supervised run whose checkpoint is loaded as the frozen feature extractor.
EPOCH = 800

BATCH_SIZE = 512
NUM_EPOCHS = 51
LEARNING_RATE = 0.03*BATCH_SIZE/256
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9


fp16 = dict(
mode=AMP_TYPE.TORCH,
)

dataset = dict(
root='./dataset',
)

gradient_accumulation=1
gradient_clipping=1.0
178 changes: 178 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/Backbone.py
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, resnet50, resnet101, resnet152


def backbone(model, **kwargs):
    assert model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "the current version only supports resnet18 ~ resnet152"
    if model == 'resnet18':
        # Use the CIFAR-friendly pre-activation ResNet18 defined below (3x3 stem, no max-pool).
        net = ResNet(PreActBlock, [2,2,2,2], **kwargs)
    else:
        # Fall back to the torchvision implementation for the larger variants.
        net = eval(f"{model}(**kwargs)")
    # Record the feature dimension and strip the classification head so the network returns backbone features.
    net.output_dim = net.fc.in_features
    net.fc = torch.nn.Identity()
    return net


def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out


class PreActBlock(nn.Module):
'''Pre-activation version of the BasicBlock.'''
expansion = 1

def __init__(self, in_planes, planes, stride=1):
super(PreActBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = conv3x3(in_planes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)

def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out)
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out += shortcut
return out


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out


class PreActBottleneck(nn.Module):
'''Pre-activation version of the original Bottleneck module.'''
expansion = 4

def __init__(self, in_planes, planes, stride=1):
super(PreActBottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)

def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out)
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out = self.conv3(F.relu(self.bn3(out)))
out += shortcut
return out


class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64

self.conv1 = conv3x3(3,64)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.fc = nn.Linear(512*block.expansion, num_classes)

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        # `lin` and `lout` select which stages to run, enabling partial forward
        # passes; the defaults (0, 5) run the full network.
        out = x
        if lin < 1 and lout > -1:
out = self.conv1(out)
out = self.bn1(out)
out = F.relu(out)
if lin < 2 and lout > 0:
out = self.layer1(out)
if lin < 3 and lout > 1:
out = self.layer2(out)
if lin < 4 and lout > 2:
out = self.layer3(out)
if lin < 5 and lout > 3:
out = self.layer4(out)
if lout > 4:
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out

def debug():
    # Smoke test: the custom CIFAR ResNet18 does not accept a `pretrained` argument,
    # so the backbone is built with default arguments.
    net = backbone('resnet18')
    x = torch.randn(4,3,32,32)
    y = net(x)
    print(y.size())

if __name__ == '__main__':
debug()
19 changes: 19 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/linear_eval.py
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone

class Linear_eval(nn.Module):
    '''Linear evaluation head: the backbone is frozen and only the final linear classifier is trained.'''

    def __init__(self, model='resnet18', class_num=10, **kwargs):
        super().__init__()

        self.backbone = backbone(model, **kwargs)
        self.backbone.requires_grad_(False)
        self.fc = nn.Linear(self.backbone.output_dim, class_num)

def forward(self, x):

out = self.backbone(x)
out = self.fc(out)
return out
36 changes: 36 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/simclr.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone

class projection_MLP(nn.Module):
    '''Two-layer projection head mapping backbone features into the space where the contrastive loss is computed.'''
    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden_dim = in_dim
self.layer1 = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.ReLU(inplace=True)
)
self.layer2 = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x

class SimCLR(nn.Module):

def __init__(self, model='resnet18', **kwargs):
super().__init__()

self.backbone = backbone(model, **kwargs)
self.projector = projection_MLP(self.backbone.output_dim)
self.encoder = nn.Sequential(
self.backbone,
self.projector
)

    def forward(self, x1, x2):
        # x1 and x2 are two augmented views of the same batch of images;
        # the contrastive loss is computed on their projections z1 and z2.
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        return z1, z2