add example of self-supervised SimCLR training - V2 #50

Merged — 9 commits merged on Dec 21, 2021
45 changes: 45 additions & 0 deletions examples/simclr_cifar10_data_parallel/NT_Xentloss.py
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.registry import LOSSES
from torch.nn import Linear


@LOSSES.register_module
class NT_Xentloss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss used by SimCLR."""

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2, label):
        # `label` is unused; it is kept so the signature matches the trainer's criterion call.
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        N, Z = z1.shape
        device = z1.device
        representations = torch.cat([z1, z2], dim=0)
        # Pairwise cosine similarities between all 2N projections.
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)
        # Positives are the (i, i+N) and (i+N, i) entries, i.e. the two views of the same image.
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)
        # Mask out self-similarities and the positive pairs to obtain the negatives.
        diag = torch.eye(2 * N, dtype=torch.bool, device=device)
        diag[N:, :N] = diag[:N, N:] = diag[:N, :N]

        negatives = similarity_matrix[~diag].view(2 * N, -1)

        # Column 0 holds the positive logit; cross-entropy with target 0 maximizes it.
        logits = torch.cat([positives, negatives], dim=1)
        logits /= self.temperature

        labels = torch.zeros(2 * N, device=device, dtype=torch.int64)

        loss = F.cross_entropy(logits, labels, reduction='sum')
        return loss / (2 * N)


if __name__ == '__main__':
    criterion = NT_Xentloss()
    net = Linear(256, 512)
    output = [net(torch.randn(512, 256)), net(torch.randn(512, 256))]
    label = [torch.randn(512)]
    loss = criterion(*output, *label)
    print(loss)
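
For reference, the logits construction above evaluates the NT-Xent objective from the SimCLR paper: with cosine similarity $\mathrm{sim}(\cdot,\cdot)$, temperature $\tau$, and $2N$ projections per batch, the per-pair term is

$$
\ell(i, j) = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j)/\tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)},
$$

and the returned value is the average of $\ell$ over all $2N$ positive pairs.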


28 changes: 28 additions & 0 deletions examples/simclr_cifar10_data_parallel/README.md
@@ -0,0 +1,28 @@
# Overview

This example trains SimCLR on CIFAR10 with a PreAct-ResNet18 backbone, using a single RTX 3090.
The training process consists of two phases: (1) self-supervised pre-training and (2) linear evaluation, as sketched below.
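
The snippet below is only a conceptual sketch of the two phases using the modules in this example (shapes and variable names are made up for illustration; the actual training loops live in the scripts launched by `train.sh`, and the sketch assumes it is run from this example's directory):

```python
import torch
from models.simclr import SimCLR
from models.linear_eval import Linear_eval
from NT_Xentloss import NT_Xentloss

# Phase 1: self-supervised pre-training on two augmented views per image.
simclr = SimCLR('resnet18')
view1, view2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
z1, z2 = simclr(view1, view2)
loss = NT_Xentloss()(z1, z2, None)

# Phase 2: linear evaluation on top of the frozen backbone.
linear = Linear_eval('resnet18')
linear.backbone.load_state_dict(simclr.backbone.state_dict())
logits = linear(torch.randn(8, 3, 32, 32))   # (8, 10) class scores
```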

# How to run
The training commands are specified in `train.sh`. To start training, run:
```shell
bash train.sh
```

Besides linear evaluation, you can also visualize the learned representations (remember to modify `log_name` and `epoch` in advance):
```shell
python visualization.py
```

# Results
The loss curve of SimCLR self-supervised training is as follows:
![SimCLR Loss Curve](./results/ssl_loss.png)
The loss curve of linear evaluation is as follows:
![Linear Evaluation Loss Curve](./results/linear_eval_loss.png)
The accuracy curve of linear evaluation is as follows:
![Linear Evaluation Accuracy](./results/linear_eval_acc.png)
The t-SNE of the training set of CIFAR10 is as follows:
![train tSNE](./results/train_tsne.png)
The t-SNE of the test set of CIFAR10 is as follows:
![test tSNE](./results/test_tsne.png)
32 changes: 32 additions & 0 deletions examples/simclr_cifar10_data_parallel/augmentation.py
@@ -0,0 +1,32 @@
from torchvision.transforms import transforms


class SimCLRTransform():
    """Produces two independently augmented views of the same image for contrastive training."""
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # 32 // 20 * 2 + 1 = 3, i.e. a 3x3 Gaussian blur kernel for 32x32 CIFAR images.
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=32 // 20 * 2 + 1, sigma=(0.1, 2.0))], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])

    def __call__(self, x):
        x1 = self.transform(x)
        x2 = self.transform(x)
        return x1, x2


class LeTransform():
    """Light augmentation used for linear evaluation."""
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])

    def __call__(self, x):
        x = self.transform(x)
        return x
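
A minimal usage sketch of `SimCLRTransform` (assuming torchvision's CIFAR10 dataset; the `root` path here is made up):

```python
from torchvision.datasets import CIFAR10
from augmentation import SimCLRTransform

dataset = CIFAR10(root='./data', train=True, download=True, transform=SimCLRTransform())
(x1, x2), label = dataset[0]   # two independently augmented views of the same image
print(x1.shape, x2.shape)      # torch.Size([3, 32, 32]) torch.Size([3, 32, 32])
```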
23 changes: 23 additions & 0 deletions examples/simclr_cifar10_data_parallel/config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE


LOG_NAME = 'cifar-simclr'

BATCH_SIZE = 512
NUM_EPOCHS = 801
# Linear LR scaling rule: base LR 0.03 at batch size 256, i.e. 0.03 * 512 / 256 = 0.06 here.
LEARNING_RATE = 0.03 * BATCH_SIZE / 256
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9


fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

dataset = dict(
    root='../../../../../datasets',
)

gradient_accumulation = 2
gradient_clipping = 1.0

23 changes: 23 additions & 0 deletions examples/simclr_cifar10_data_parallel/le_config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE


LOG_NAME = 'cifar-simclr'
# Epoch of the pre-trained SimCLR checkpoint to evaluate (see README).
EPOCH = 800

BATCH_SIZE = 512
NUM_EPOCHS = 51
LEARNING_RATE = 0.03 * BATCH_SIZE / 256
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9


fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

dataset = dict(
    root='../../../../../datasets',
)

gradient_accumulation = 1
gradient_clipping = 1.0
178 changes: 178 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/Backbone.py
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, resnet50, resnet101, resnet152


def backbone(model, **kwargs):
    """Builds a ResNet encoder and strips its classification head."""
    assert model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "the current version only supports resnet18 ~ resnet152"
    if model == 'resnet18':
        # CIFAR-sized PreAct-ResNet18 defined below.
        net = ResNet(PreActBlock, [2, 2, 2, 2], **kwargs)
    else:
        net = eval(f"{model}(**kwargs)")
    net.output_dim = net.fc.in_features
    net.fc = torch.nn.Identity()
    return net


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        # `lin` and `lout` select which stages to run, allowing partial forward passes.
        out = x
        if lin < 1 and lout > -1:
            out = self.conv1(out)
            out = self.bn1(out)
            out = F.relu(out)
        if lin < 2 and lout > 0:
            out = self.layer1(out)
        if lin < 3 and lout > 1:
            out = self.layer2(out)
        if lin < 4 and lout > 2:
            out = self.layer3(out)
        if lin < 5 and lout > 3:
            out = self.layer4(out)
        if lout > 4:
            out = F.avg_pool2d(out, 4)
            out = out.view(out.size(0), -1)
            out = self.fc(out)
        return out


def debug():
    # Note: the custom CIFAR ResNet does not accept torchvision kwargs such as `pretrained`.
    net = backbone('resnet18')
    x = torch.randn(4, 3, 32, 32)
    y = net(x)
    print(y.size())


if __name__ == '__main__':
    debug()
19 changes: 19 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/linear_eval.py
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone


class Linear_eval(nn.Module):

    def __init__(self, model='resnet18', class_num=10, **kwargs):
        super().__init__()

        self.backbone = backbone(model, **kwargs)
        # Freeze the encoder; only the linear head is trained during evaluation.
        # Note that BatchNorm running statistics are still updated in train() mode.
        self.backbone.requires_grad_(False)
        self.fc = nn.Linear(self.backbone.output_dim, class_num)

    def forward(self, x):
        out = self.backbone(x)
        out = self.fc(out)
        return out
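
A quick sanity-check sketch: after construction, only the linear head should have trainable parameters.

```python
from models.linear_eval import Linear_eval

model = Linear_eval('resnet18')
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)   # expected: ['fc.weight', 'fc.bias']
```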
36 changes: 36 additions & 0 deletions examples/simclr_cifar10_data_parallel/models/simclr.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone


class projection_MLP(nn.Module):
    """Two-layer projection head mapping backbone features to the contrastive embedding space."""
    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden_dim = in_dim
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


class SimCLR(nn.Module):

    def __init__(self, model='resnet18', **kwargs):
        super().__init__()

        self.backbone = backbone(model, **kwargs)
        self.projector = projection_MLP(self.backbone.output_dim)
        self.encoder = nn.Sequential(
            self.backbone,
            self.projector
        )

    def forward(self, x1, x2):
        # Encode both augmented views; the NT-Xent loss is computed on (z1, z2).
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        return z1, z2
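
A shape-check sketch (batch size is arbitrary): for `resnet18` the backbone produces 512-d features and the projector maps them to the 256-d embeddings fed to NT-Xent.

```python
import torch
from models.simclr import SimCLR

model = SimCLR('resnet18')
x1 = torch.randn(4, 3, 32, 32)
x2 = torch.randn(4, 3, 32, 32)
z1, z2 = model(x1, x2)
print(model.backbone.output_dim, z1.shape, z2.shape)   # 512 torch.Size([4, 256]) torch.Size([4, 256])
```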