In [1]:
!pip install torch==1.7.1
!pip install torchvision==0.8.2
!pip install torch-mimicry

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch==1.7.1
  Downloading torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl (776.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.8/776.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.13.1+cu116
    Uninstalling torch-1.13.1+cu116:
      Successfully uninstalled torch-1.13.1+cu116
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.14.1+cu116 requires torch==1.13.1, but you have torch 1.7.1 which is incompatible.
torchtext 0.14.1 requires torch==1.13.1, but you have torch 1.7.1 which is incompatible.
torchaudio 0.13.1+cu116 requires torch==1.13.1, but you have torch 1.7.1 which is incompatible.[0m[31m
[0mSucces

In [2]:
import torch
import torch.optim as optim
import torch_mimicry as mmc

In [3]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [5]:
import sys
sys.path.append('/content/drive/MyDrive/SSD-GAN-main')
import models.ssd_sngan_32 as ssd_sngan

In [6]:
# 支持选择图片库 ('CIFAR10', 'CIFAR100', 'STL10', 'FashionMNIST' )，选择sample size，选择image大小，和grayscale
import os
from torch.utils.data import random_split
from torchvision import transforms, datasets

def dataset_split_shape(name, n = 1000, size = 32, grayScale = False, convert_tensor=True, transform_data = True, root='./datasets/', download=True):
  
  dataset_dir = os.path.join(root, f"{name}_{n}_{size}")
  if not os.path.exists(dataset_dir):
    os.makedirs(dataset_dir)

  if transform_data:
    transform_list = [transforms.ToTensor(),
                      transforms.Resize(size)]
    if grayScale:
      transform_list.append(transforms.Grayscale()) 
    if convert_tensor:
      transform_list.append(transforms.Normalize((0.5, ), (0.5, )))
  else:
    transform_list = []

  transformer = transforms.Compose(transform_list)

  if name == 'CIFAR10': 
    dataset = datasets.CIFAR10(
      root=dataset_dir,
      download=download,
      transform=transforms.Compose(transform_list)
  )
  elif name == 'CIFAR100': 
    dataset = datasets.CIFAR100(
      root=dataset_dir,
      download=download,
      transform=transforms.Compose(transform_list)
  )
  elif name == 'STL10': #与原package相同，默认选择unlabeled的数据
    dataset = datasets.STL10(
      root=dataset_dir,
      download=download,
      split='unlabeled',
      transform=transforms.Compose(transform_list)
  )
  elif name == 'FashionMNIST': 
    dataset = datasets.FashionMNIST(
      root=dataset_dir,
      download=download,
      transform=transforms.Compose(transform_list)
  )
  else:
    print("invalid name")
    return 
  if n is None or n > len(dataset):
    return dataset
  generator1 = torch.Generator().manual_seed(42)
  a, b = random_split(dataset, [n, len(dataset)-n], generator = generator1)
  return a

In [7]:
#选取CIFAR里面的500张图，size改为32*32
CIFAR10_500_32 = dataset_split_shape('CIFAR10', n = 500, size = 32)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/CIFAR10_500_32/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting ./datasets/CIFAR10_500_32/cifar-10-python.tar.gz to ./datasets/CIFAR10_500_32


In [8]:
print('check dataset')
print(f"number of images: {len(CIFAR10_500_32)}")
print(f"shape of images: {CIFAR10_500_32[0][0].shape}")

check dataset
number of images: 500
shape of images: torch.Size([3, 32, 32])


In [10]:
#建立dataloader
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
CIFAR10_500_32_dataloader = torch.utils.data.DataLoader(
    CIFAR10_500_32, batch_size=64, shuffle=True, num_workers=4)

In [11]:
# Define models and optimizers
netG = ssd_sngan.SSD_SNGANGenerator32().to(device)
netD = ssd_sngan.SSD_SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=1000,
    num_steps=5, # number of iterations
    lr_decay='linear',
    dataloader= CIFAR10_500_32_dataloader,
    log_dir='./log/CIFAR10_500_32', #自定义地址
    device=device)

In [12]:
trainer.train()

INFO: Starting training from global step 0...


  fft = torch.rfft(x_gray,2,onesided=False)


INFO: [Epoch 1/1][Global Step: 1/5] 
| D(G(z)): 1.6799
| D(x): 2.0503
| errC: 1.9801
| errD: 2.6799
| errG: 4.2677
| lr_D: 0.00016
| lr_G: 0.00016
| (73.7939 sec/idx)
INFO: Saving checkpoints from keyboard interrupt...
INFO: Training Ended.


In [None]:
#如果metric是fid, 则需自定义stats_file,num_samples = number of real images;
def create_stats_file(log_dir, num_real_samples, seed, dataset, metric):
  stats_dir = os.path.join(log_dir, 'metrics', metric, 'statistics')
  if not os.path.exists(stats_dir):
    os.makedirs(stats_dir)

  stats_file = os.path.join(
            stats_dir,
            "fid_stats_{}_{}k_run_{}.npz".format(dataset, num_real_samples // 1000,
                                                 seed))
  return stats_file

#如果metric是kid: 
def create_feat_file(log_dir, num_samples, seed, dataset, metric):
  stats_dir = os.path.join(log_dir, 'metrics', metric, 'statistics')
  if not os.path.exists(stats_dir):
    os.makedirs(stats_dir)

  stats_file = os.path.join(
            stats_dir,
            "fid_stats_{}_{}k_run_{}.npz".format(dataset, num_samples // 1000,
                                                 seed))
  return stats_file



In [None]:
#fid举例
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/CIFAR10_500_32',
    netG=netG,
    num_real_samples = 500,
    num_fake_samples = 500, 
    dataset = CIFAR10_500_32,
    evaluate_step=1000,
    start_seed=0,
    num_runs=1,
    device=device,
    stats_file = create_stats_file('./log/CIFAR10_500_32', 20, 0, CIFAR10_500_32, 'fid'))

In [None]:
# kid举例
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/CIFAR10_500_32',
    netG=netG,
    num_samples = 500,
    dataset = CIFAR10_500_32,
    evaluate_step=1000,
    start_seed=0,
    num_runs=1,
    device=device,
    feat_file = create_feat_file('./log/CIFAR10_500_32', 80, 0, CIFAR10_500_32, 'kid'))

In [None]:
#IC举例
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/CIFAR10_500_32',
    netG=netG,
    num_samples = 500,
    evaluate_step=1000,
    start_seed=0,
    num_runs=1,
    device=device)


In [None]:
Log=mmc.training.Logger(log_dir='./log/CIFAR10_500_32', num_steps=1000, dataset_size=500, device=device)

NameError: ignored

In [None]:
Log.vis_images(netG=netG,global_step=1000)