In [13]:
#!/usr/bin/env python
import argparse
import builtins
import os
import shutil
import copy

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torchvision.models as models

import sys
sys.path.extend(['..', '.'])
from datasets.dataset_tinyimagenet import load_train, load_val_loader, num_classes_dict
from tools.store import ExperimentLogWriter
import models.builder as model_builder
import utils

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
model_names += ['resnet18_cifar_variant1']

args = argparse.Namespace(
    dataset='cifar10',
    arch='resnet18_cifar_variant1',
    workers=32,
    epochs=100,
    start_epoch=0,
    batch_size=256,
    lr=30.0,
    schedule=[60, 80],
    momentum=0.9,
    weight_decay=0.0,
    print_freq=10,
    evaluate=False,
    world_size=-1,
    rank=-1,
    dist_url='tcp://224.66.41.62:23456',
    dist_backend='nccl',
    seed=None,
    gpu=None,
    multiprocessing_distributed=False,
    opt='sgd',
    dir='log/spectral/completed-2023-05-13spectral-resnet18-mlp1000-norelu-cifar10-lr003-mu1-log_freq:20',
    num_per_class=int(1e10),
    val_every=5,
    latest_only=True,
    mpd=False,
    dist_url_add=0,
    specific_ckpts=None,
    use_random_labels=False,
    normalize=False,
    nomlp=True,
    aug='standard'
)


In [14]:
def main():
#     args = parser.parse_args()
    if args.mpd:
        args.multiprocessing_distributed = True
        args.world_size = 1
        args.rank = 0
        args.dist_url = 'tcp://127.0.0.1:' + str(10001 + args.dist_url_add)
    utils.spawn_processes(main_worker, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu
    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        utils.init_proc_group(args, ngpus_per_node)
    
    logger = ExperimentLogWriter(args.dir)

    # loop through checkpoints and set pre-trained
    ckpt_dir = os.path.join(args.dir, 'checkpoints')
    for fname in sorted(os.listdir(ckpt_dir)):
        if args.latest_only and not fname.startswith('latest_'): continue
        if args.specific_ckpts is not None and fname not in args.specific_ckpts: continue
        args.pretrained = os.path.join(ckpt_dir, fname)

        lineval_dir = os.path.join(args.dir, 'lin_eval_ckpt')
        if os.path.exists(lineval_dir):
            print('linear evaluation dir exists at {}, may overwrite...'.format(lineval_dir))
        eval_ckpt(
            copy.deepcopy(args), # because args.batch_size and args.workers are changed
            ngpus_per_node,
            fname,
            logger)

In [15]:
def get_embeddings(model, dataloader, device):
    # Print all layers in the original model
    print("Original model layers:")
    for idx, layer in enumerate(model.children()):
        print(f"Layer {idx}: {layer}")
    print()

    # Copy the model without its last layer
    embedding_model = nn.Sequential(*(list(model.children())[:-1]))
    embedding_model = embedding_model.to(device)

    # Print all layers in the embedding model
    print("Embedding model layers:")
    for idx, layer in enumerate(embedding_model.children()):
        print(f"Layer {idx}: {layer}")

    # Set the model in evaluation mode
    embedding_model.eval()

    # Initialize an empty tensor to store all the embeddings
    all_embeddings = torch.empty((0, 512)).to(device)
    
    return 0

    with torch.no_grad():
        for i, (images, _) in enumerate(dataloader):
            images = images.to(device)
            
            print(f"{images.shape = }")

            # Get embeddings for this batch and flatten them
            print(f"{embedding_model(images).shape = }")
            embeddings = embedding_model(images).view(images.size(0), -1)
            
            print(f"{embeddings.shape = }")

            # Concatenate with the previous embeddings
            all_embeddings = torch.cat((all_embeddings, embeddings))

    return all_embeddings


In [18]:
def eval_ckpt(args, ngpus_per_node, ptrain_fname, logger):
    # create model
    pretrained_id = ptrain_fname.split('.')[0]
    dict_id = pretrained_id + '_lineval'
    dict_id += '_{}_lr:{}_wd:{}_{}eps'.format(args.opt, args.lr, args.weight_decay, args.epochs)
    if args.nomlp:
        dict_id = dict_id + '_nomlp'
    dict_id += '_aug:' + args.aug
    dict_id += '_random_labels' if args.use_random_labels else ''
    ckpt_dir = os.path.join(args.dir, 'lin_eval_ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    ptrain_fname += '_random_labels' if args.use_random_labels else ''
    lin_eval_loc = os.path.join(ckpt_dir, ptrain_fname)

    logger.create_data_dict(
        ['epoch', 'train_acc', 'val_acc','train_loss', 'val_loss', 'train5', 'val5'],
        dict_id=dict_id)

    model = model_builder.get_model(num_classes_dict[args.dataset], arch=args.arch)

    # freeze all layers but the last fc
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False
    # init the fc layer
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained, map_location='cpu')
            state_dict = checkpoint['state_dict']
            model_builder.load_checkpoint(model, state_dict, args.pretrained, args=args, nomlp=args.nomlp)

            # Get the last layer
            last_layer_name, last_layer = list(model.named_children())[-1]
            print(f"Last layer: {last_layer_name}")
            print(f"Last layer's shape: {last_layer.weight.shape}")

            ###########################################
            # Output:
            # Last layer: fc
            # Last layer's shape: torch.Size([10, 512])
            ###########################################

        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))
    
    model = utils.init_data_parallel(args, model, ngpus_per_node)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias
    if args.opt=='sgd':
        optimizer = torch.optim.SGD(parameters, args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.opt=='adam':
        optimizer = torch.optim.Adam(parameters, lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    if args.use_random_labels:
        random_labels = torch.load(os.path.join(args.dir, 'saved_tensors', 'random_labels.pth')).numpy()
    else:
        random_labels = None
    train_sampler, train_loader = load_train(args.dataset, args.num_per_class, args.distributed,
                                             args.batch_size, args.workers, data_aug=args.aug, random_labels=random_labels)

    embeddings = get_embeddings(model, train_loader, args.gpu)
#     print(f"Embeddings shape: {embeddings.shape}")  # Should be [60000, 512]

In [19]:
main()

linear evaluation dir exists at log/spectral/completed-2023-05-13spectral-resnet18-mlp1000-norelu-cifar10-lr003-mu1-log_freq:20/lin_eval_ckpt, may overwrite...
=> loading checkpoint 'log/spectral/completed-2023-05-13spectral-resnet18-mlp1000-norelu-cifar10-lr003-mu1-log_freq:20/checkpoints/latest_800.pth'
=> loaded pre-trained model 'log/spectral/completed-2023-05-13spectral-resnet18-mlp1000-norelu-cifar10-lr003-mu1-log_freq:20/checkpoints/latest_800.pth'
Last layer: fc
Last layer's shape: torch.Size([10, 512])
Original model layers:
Layer 0: ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      