In [1]:
import torch
import torch.nn as nn
import torch.utils.data as td
import torchvision as tv
import torchvision.transforms as tf
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from ckpt_manager import CheckpointManager
from tqdm import tqdm
import numpy as np

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
torch.cuda.device_count()

1

In [3]:
class ResNetDecBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, upsample=None):
        super().__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(in_channels))
        if stride != 1:
            self.conv2 = nn.Sequential(
                            nn.Upsample(scale_factor=stride),
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                            nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential(
                            nn.Upsample(scale_factor=stride),
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1))
        else:
            self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential()
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

In [4]:
class ResNet18Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.inchannel = 64
        backbone = torchvision.models.resnet18(weights="DEFAULT")

        self.conv1 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        return out

In [5]:
class ResNet18Decoder(nn.Module):
    def __init__(self, Block):
        super().__init__()
        self.inchannel = 512
        self.up = nn.Upsample(scale_factor=2)
        self.layer1 = self.make_layer(Block, 256, 2, stride=2)
        self.layer2 = self.make_layer(Block, 128, 2, stride=2)
        self.layer3 = self.make_layer(Block, 64, 2, stride=2)
        self.layer4 = self.make_layer(Block, 64, 2, stride=1)
        self.resize = nn.Sequential(
                            nn.Upsample(scale_factor=4),
                            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
        
    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = x.view(-1, 512, 8, 8)
        out = self.up(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.resize(out)
        return out

In [6]:
class ResNetAutoencoder(nn.Module):
    def __init__(self, quantize_factor):
        super().__init__()
        self.enc = ResNet18Encoder()
        self.dec = ResNet18Decoder(ResNetDecBlock)
        self.quantize_factor = quantize_factor
    
    def forward(self, x):
        out = self.enc(x)
        out = torch.clamp(out, 0.0, 1.0)
        out = out + (1 / 2 ** self.quantize_factor) * (torch.rand_like(out) * 0.5 - 0.5)
        out = self.dec(out)
        return F.sigmoid(out)

In [7]:
transform = tf.Compose(
        [
         tf.ToTensor()
        ]
)

batch_size = 64

In [8]:
dataset = torchvision.datasets.ImageFolder('universal-dataset', transform=transform)
loader = td.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [18]:
# !rm -rf model_B=2
# !rm -rf model_B=256
# !rm -rf model_B=8
# !rm -rf training_checkpoints_B=2
# !rm -rf training_checkpoints_B=256
# !rm -rf training_checkpoints_B=8

In [19]:
!mkdir training_checkpoints_B=10
!mkdir training_checkpoints_B=2

In [20]:
model = ResNetAutoencoder(2).cuda()

In [21]:
manager = CheckpointManager(
    assets={
        'model' : model.state_dict()
    },
    directory='training_checkpoints_B=2',
    file_name='ResNetAutoencoder',
    maximum=10,
    file_format='pt'
)

In [22]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
criterion = nn.MSELoss()

model.train()
1

1

In [23]:
EPOCHS = 3
SAVE_LOG_ITERS = 100
losses = []
for epoch in range(EPOCHS):
    for i, (batch, _) in tqdm(enumerate(loader)):
        batch = batch.cuda()
        
        predicted = model(batch)
        
        loss = criterion(predicted, batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if (i + 1) % SAVE_LOG_ITERS == 0:
            print(f"Epoch: {epoch} | iter: {i} | loss: {np.array(losses[-SAVE_LOG_ITERS:]).mean()}")
            manager.save()
    scheduler.step()

100it [00:42,  2.28it/s]

Epoch: 0 | iter: 99 | loss: 0.0355590483173728
Saved states to training_checkpoints_B=2/ResNetAutoencoder_1.pt


200it [01:23,  2.27it/s]

Epoch: 0 | iter: 199 | loss: 0.022535147368907927
Saved states to training_checkpoints_B=2/ResNetAutoencoder_2.pt


300it [02:04,  2.25it/s]

Epoch: 0 | iter: 299 | loss: 0.019365734197199346
Saved states to training_checkpoints_B=2/ResNetAutoencoder_3.pt


400it [02:46,  2.29it/s]

Epoch: 0 | iter: 399 | loss: 0.01797392622567713
Saved states to training_checkpoints_B=2/ResNetAutoencoder_4.pt


500it [03:27,  2.31it/s]

Epoch: 0 | iter: 499 | loss: 0.016411674367263915
Saved states to training_checkpoints_B=2/ResNetAutoencoder_5.pt


600it [04:09,  2.29it/s]

Epoch: 0 | iter: 599 | loss: 0.01515758628025651
Saved states to training_checkpoints_B=2/ResNetAutoencoder_6.pt


699it [04:50,  2.37it/s]

Epoch: 0 | iter: 699 | loss: 0.01469973341561854


700it [04:50,  2.07it/s]

Saved states to training_checkpoints_B=2/ResNetAutoencoder_7.pt


800it [05:32,  2.30it/s]

Epoch: 0 | iter: 799 | loss: 0.013909212425351142
Saved states to training_checkpoints_B=2/ResNetAutoencoder_8.pt


900it [06:13,  2.27it/s]

Epoch: 0 | iter: 899 | loss: 0.012989940559491516
Saved states to training_checkpoints_B=2/ResNetAutoencoder_9.pt


1000it [06:55,  2.28it/s]

Epoch: 0 | iter: 999 | loss: 0.01289048781618476
Saved states to training_checkpoints_B=2/ResNetAutoencoder_10.pt


1100it [07:36,  2.23it/s]

Epoch: 0 | iter: 1099 | loss: 0.012800270058214665
Saved states to training_checkpoints_B=2/ResNetAutoencoder_11.pt


1200it [08:18,  2.21it/s]

Epoch: 0 | iter: 1199 | loss: 0.012175543988123537
Saved states to training_checkpoints_B=2/ResNetAutoencoder_11.pt


1300it [09:00,  2.20it/s]

Epoch: 0 | iter: 1299 | loss: 0.011981275482103228
Saved states to training_checkpoints_B=2/ResNetAutoencoder_12.pt


1400it [09:42,  2.25it/s]

Epoch: 0 | iter: 1399 | loss: 0.01179324796423316
Saved states to training_checkpoints_B=2/ResNetAutoencoder_12.pt


1500it [10:24,  2.13it/s]

Epoch: 0 | iter: 1499 | loss: 0.011281449245288969
Saved states to training_checkpoints_B=2/ResNetAutoencoder_13.pt


1600it [11:06,  2.24it/s]

Epoch: 0 | iter: 1599 | loss: 0.011265694359317423
Saved states to training_checkpoints_B=2/ResNetAutoencoder_13.pt


1700it [11:47,  2.21it/s]

Epoch: 0 | iter: 1699 | loss: 0.011051364298909903
Saved states to training_checkpoints_B=2/ResNetAutoencoder_14.pt


1800it [12:29,  2.28it/s]

Epoch: 0 | iter: 1799 | loss: 0.010582346022129058
Saved states to training_checkpoints_B=2/ResNetAutoencoder_14.pt


1900it [13:10,  2.25it/s]

Epoch: 0 | iter: 1899 | loss: 0.010553488554432989
Saved states to training_checkpoints_B=2/ResNetAutoencoder_15.pt


2000it [13:52,  2.26it/s]

Epoch: 0 | iter: 1999 | loss: 0.010450115036219359
Saved states to training_checkpoints_B=2/ResNetAutoencoder_15.pt


2071it [14:21,  2.40it/s]
100it [00:41,  2.24it/s]

Epoch: 1 | iter: 99 | loss: 0.010190647235140205
Saved states to training_checkpoints_B=2/ResNetAutoencoder_16.pt


200it [01:23,  2.29it/s]

Epoch: 1 | iter: 199 | loss: 0.010102741103619337
Saved states to training_checkpoints_B=2/ResNetAutoencoder_16.pt


300it [02:04,  2.23it/s]

Epoch: 1 | iter: 299 | loss: 0.00978135091252625
Saved states to training_checkpoints_B=2/ResNetAutoencoder_17.pt


400it [02:46,  2.25it/s]

Epoch: 1 | iter: 399 | loss: 0.009659050046466292
Saved states to training_checkpoints_B=2/ResNetAutoencoder_17.pt


500it [03:28,  2.21it/s]

Epoch: 1 | iter: 499 | loss: 0.009665857595391571
Saved states to training_checkpoints_B=2/ResNetAutoencoder_18.pt


600it [04:09,  2.26it/s]

Epoch: 1 | iter: 599 | loss: 0.009516026130877436
Saved states to training_checkpoints_B=2/ResNetAutoencoder_18.pt


700it [04:51,  2.21it/s]

Epoch: 1 | iter: 699 | loss: 0.009208576139062643
Saved states to training_checkpoints_B=2/ResNetAutoencoder_19.pt


800it [05:33,  2.24it/s]

Epoch: 1 | iter: 799 | loss: 0.009325693282298743
Saved states to training_checkpoints_B=2/ResNetAutoencoder_19.pt


900it [06:15,  2.24it/s]

Epoch: 1 | iter: 899 | loss: 0.009173358106054366
Saved states to training_checkpoints_B=2/ResNetAutoencoder_20.pt


1000it [06:57,  2.27it/s]

Epoch: 1 | iter: 999 | loss: 0.008868359224870802
Saved states to training_checkpoints_B=2/ResNetAutoencoder_21.pt


1100it [07:38,  2.27it/s]

Epoch: 1 | iter: 1099 | loss: 0.008960101841948927
Saved states to training_checkpoints_B=2/ResNetAutoencoder_22.pt


1125it [07:49,  2.38it/s]

In [24]:
!mkdir model_B=2
torch.save(model.enc.state_dict(), 'model_B=2/encoder.model')
torch.save(model.dec.state_dict(), 'model_B=2/decoder.model')

In [None]:
model = ResNetAutoencoder(10).cuda()
manager = CheckpointManager(
    assets={
        'model' : model.state_dict()
    },
    directory='training_checkpoints_B=10',
    file_name='ResNetAutoencoder',
    maximum=10,
    file_format='pt'
)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
criterion = nn.MSELoss()

model.train()

EPOCHS = 3
SAVE_LOG_ITERS = 100
losses = []
for epoch in range(EPOCHS):
    for i, (batch, _) in tqdm(enumerate(loader)):
        batch = batch.cuda()
        
        predicted = model(batch)
        
        loss = criterion(predicted, batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if (i + 1) % SAVE_LOG_ITERS == 0:
            print(f"Epoch: {epoch} | iter: {i} | loss: {np.array(losses[-SAVE_LOG_ITERS:]).mean()}")
            manager.save()
    scheduler.step()

!mkdir model_B=10
torch.save(model.enc.state_dict(), 'model_B=10/encoder.model')
torch.save(model.dec.state_dict(), 'model_B=10/decoder.model')