In [1]:
import torch
from torchvision import models
from torchvision import datasets, transforms
from datasets import Split_Dataset
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

test_dataset = datasets.ImageFolder('/gpfs/u/locker/200/CADS/datasets/ImageNet/val', transform=val_transforms)

val_dataset = Split_Dataset('/gpfs/u/locker/200/CADS/datasets/ImageNet',  \
                    f'./calib_splits/am_imagenet_5percent_val.txt',
                    transform=val_transforms)

test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=True,
            num_workers=20, pin_memory=True,
        )
val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )

In [8]:
indices = np.arange(len(test_dataset))
test_dataset = Subset(test_dataset, list(indices))

subset_test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )


In [11]:
print(subset_test_loader.dataset.dataset)

Dataset ImageFolder
    Number of datapoints: 50000
    Root location: /gpfs/u/locker/200/CADS/datasets/ImageNet/val
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=PIL.Image.BILINEAR)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )


In [10]:
subset_test_loader.dataset.dataset.transform

Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [21]:
import os
split_idx_path = './misc/imagenet_train_val_split.pth'
if os.path.exists(split_idx_path):
    split_idx = torch.load(split_idx_path)
else:
    raise Exception('Imagenet train-val split dict file does not exist! git pull ')
test_dataset = Subset(test_dataset, split_idx[0]['val'])

In [22]:
subset_test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )

In [26]:
print(test_dataset.dataset.dataset)

Dataset ImageFolder
    Number of datapoints: 50000
    Root location: /gpfs/u/locker/200/CADS/datasets/ImageNet/val
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=PIL.Image.BILINEAR)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )


In [3]:
def ensem_pred(outputs, mode='DE', num_ensem=3, target=None):
    output = outputs.softmax(dim=-1).mean(dim=0)
    _, ensem_preds = output.max(1)
    batch_acc = (ensem_preds.long() == target).sum()/ len(target)
    conf, preds = outputs.max(dim=-1) # each has dim (M,B)
    pred_exist = (preds == target).sum(dim=0).bool()
    acc_ub = 1 - (~pred_exist).sum()/len(pred_exist)

    return ensem_preds, target, output, acc_ub * 100

In [4]:
model1 = models.resnet50().cuda()
model2 = models.resnet50().cuda()
model3 = models.resnet50().cuda()

sd = torch.load("./dist_models/ft95perc_baseR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model1.load_state_dict(ckpt)
model1.eval()

sd = torch.load("./dist_models/ft95perc_eqR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model2.load_state_dict(ckpt)
model2.eval()

sd = torch.load("./dist_models/ft95perc_inv_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model3.load_state_dict(ckpt)
model3.eval()

gate = models.resnet18(num_classes=3).cuda()

In [5]:
w_acc = 0
n_acc = 0

targets = []
for it, (img,target) in enumerate(val_loader):
    target = target.cuda(non_blocking=True)
    img = img.cuda(non_blocking=True)
    with torch.no_grad():
        output1 = model1(img)
        output2 = model2(img)
        output3 = model3(img)
        preds = torch.stack([output1,output2,output3])
        _, all_preds = preds.max(-1)
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
#     print(w_acc, n_acc)
    targets.append(label_matrix)
    
print(w_acc/len(val_dataset), n_acc/len(val_dataset))
print(len(targets), targets[0].shape, targets[-1].shape)
all_targets = torch.cat(targets)
print(all_targets.shape)
num_correct = all_targets.sum(-1)
print((num_correct == 3).sum())
print((num_correct == 2).sum())
print((num_correct == 1).sum())
print((num_correct == 0).sum())


tensor(0.8626, device='cuda:0') tensor(0.8148, device='cuda:0')
251 torch.Size([256, 3]) torch.Size([102, 3])
torch.Size([64102, 3])
tensor(45806, device='cuda:0')
tensor(5462, device='cuda:0')
tensor(4024, device='cuda:0')
tensor(8810, device='cuda:0')


In [6]:
len(val_dataset)

64102

In [6]:
print(45806/64102)
print(5462/64102)
print(4024/64102)
print(8810/64102)

mask = (num_correct == 2) | (num_correct == 1)
print(mask.sum(), len(mask))

0.7145798883030171
0.08520794982995851
0.0627749524195813
0.13743720944744314
tensor(9486, device='cuda:0') 64102


In [8]:
indices = torch.where(mask == True)[0].cpu().numpy()
print(indices)
np.random.shuffle(indices)
print(indices)
val_ds = Split_Dataset('/gpfs/u/locker/200/CADS/datasets/ImageNet',  \
                    f'./calib_splits/am_imagenet_5percent_val.txt',
                    transform=val_transforms)
subset_val = Subset(val_ds, list(indices))

subset_val_loader = torch.utils.data.DataLoader(
            subset_val, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )


[   77    78    96 ... 64089 64090 64096]
[49883 60066 53164 ...  2996 45079  4174]


In [178]:
w_acc = 0
n_acc = 0

targets = []
for it, (img, target) in enumerate(subset_val_loader):
    target = target.cuda(non_blocking=True)
    img = img.cuda(non_blocking=True)
    with torch.no_grad():
        output1 = model1(img)
        output2 = model2(img)
        output3 = model3(img)
        preds = torch.stack([output1,output2,output3])
        _, all_preds = preds.max(-1)
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
    targets.append(label_matrix)
    
print(w_acc/len(val_dataset), n_acc/len(val_dataset))
print(len(targets), targets[0].shape, targets[-1].shape)
all_targets = torch.cat(targets)
print(all_targets.shape)
num_correct = all_targets.sum(-1)
print((num_correct == 3).sum())
print((num_correct == 2).sum())
print((num_correct == 1).sum())
print((num_correct == 0).sum())

tensor(0.1480, device='cuda:0') tensor(0.0998, device='cuda:0')
38 torch.Size([256, 3]) torch.Size([14, 3])
torch.Size([9486, 3])
tensor(0, device='cuda:0')
tensor(5462, device='cuda:0')
tensor(4024, device='cuda:0')
tensor(0, device='cuda:0')


In [191]:
gate = models.resnet18(num_classes=3).cuda()

optimizer = torch.optim.Adam(gate.parameters(), lr=0.5, weight_decay=0.)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \
        patience=1, verbose=True, threshold=0.01, factor = 0.3)

In [192]:
# train loop
w_acc = 0
n_acc = 0

for ep in range(20):
    total_loss = 0.
    for it, data in enumerate(zip(subset_val_loader,targets)):
        img = (data[0][0]).cuda(non_blocking=True)
        label_matrix = (data[1]).cuda(non_blocking=True)
        label_matrix = F.normalize(label_matrix, p =1)
        gate_out = gate(img)
        gate_out = F.normalize(gate_out, p =1)
        loss = F.mse_loss(gate_out,label_matrix)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if it % 10 == 0:
            print(f"Epoch {ep} | Step {it}/{len(subset_val_loader)}: {loss.item()}")
    total_loss /= len(subset_val_loader)
    scheduler.step(total_loss)
    print("out:", gate_out[:3])
    print("target:", label_matrix[:3])
    
    print(f"Epoch {ep} completed. Avg loss: {total_loss}")
        

Epoch 0 | Step 0/38: 0.33700257539749146
Epoch 0 | Step 10/38: 0.12781676650047302
Epoch 0 | Step 20/38: 0.12650315463542938
Epoch 0 | Step 30/38: 0.12327754497528076
out: tensor([[0.3328, 0.3340, 0.3332],
        [0.3328, 0.3340, 0.3332],
        [0.3327, 0.3340, 0.3333],
        [0.3328, 0.3340, 0.3332],
        [0.3329, 0.3340, 0.3331],
        [0.3331, 0.3341, 0.3328],
        [0.3323, 0.3334, 0.3343],
        [0.3328, 0.3340, 0.3332],
        [0.3328, 0.3340, 0.3332],
        [0.3329, 0.3340, 0.3331]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.5000],
        [0.5000, 0.0000, 0.5000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 1.0000, 0.0000]], device='cuda:0')
Epoch 0 completed. Avg loss: 0.1321294754743576
Epoch 1 | Step 0/38: 0.12518744

KeyboardInterrupt: 

In [None]:
gate = models.resnet18(num_classes=3).cuda()

optimizer = torch.optim.Adam(gate.parameters(), lr=0.5, weight_decay=0.)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \
        patience=1, verbose=True, threshold=0.01, factor = 0.3)

In [190]:
### BCE LOSS ##

# train loop
w_acc = 0
n_acc = 0
criterion = torch.nn.BCELoss()
for ep in range(100):
    total_loss = 0.
    for it, data in enumerate(zip(subset_val_loader,targets)):
        img = (data[0][0]).cuda(non_blocking=True)
        label_matrix = (data[1]).cuda(non_blocking=True)
        gate_out = gate(img)
        sig_out = torch.sigmoid(gate_out)
        loss = criterion(sig_out,label_matrix)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if it % 10 == 0:
            print(f"Epoch {ep} | Step {it}/{len(val_loader)}: {loss.item()}")
    total_loss /= len(val_loader)
    scheduler.step(total_loss)
    print("out:", sig_out[:3])
    print("target:", label_matrix[:3])
    
    print(f"Epoch {ep} completed. Avg loss: {total_loss}")


Epoch 0 | Step 0/40: 47.265625
Epoch 0 | Step 10/40: 47.786460876464844
Epoch 0 | Step 20/40: 47.52604293823242
Epoch 0 | Step 30/40: 46.875
out: tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 1., 0.],
        [0., 1., 1.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0')
Epoch 0 completed. Avg loss: 45.15951156616211
Epoch 1 | Step 0/40: 47.265625
Epoch 1 | Step 10/40: 47.786460876464844
Epoch 1 | Step 20/40: 47.52604293823242
Epoch 1 | Step 30/40: 46.875
Epoch    12: reducing learning rate of group 0 to 1.2150e-03.
out: tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
  

Epoch 12 | Step 0/40: 47.265625
Epoch 12 | Step 10/40: 47.786460876464844
Epoch 12 | Step 20/40: 47.52604293823242
Epoch 12 | Step 30/40: 46.875
out: tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 1., 0.],
        [0., 1., 1.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0')
Epoch 12 completed. Avg loss: 45.15951156616211
Epoch 13 | Step 0/40: 47.265625
Epoch 13 | Step 10/40: 47.786460876464844
Epoch 13 | Step 20/40: 47.52604293823242
Epoch 13 | Step 30/40: 46.875
Epoch    24: reducing learning rate of group 0 to 8.8573e-07.
out: tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1.

KeyboardInterrupt: 

In [12]:
val_ds = Split_Dataset('/gpfs/u/locker/200/CADS/datasets/ImageNet',  \
                    f'./calib_splits/am_imagenet_5percent_val.txt',
                    transform=val_transforms)
subsubset_val = Subset(val_ds, list(indices)[:1000])

subsubset_val_loader = torch.utils.data.DataLoader(
            subsubset_val, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )

w_acc = 0
n_acc = 0

targets = []
for it, (img, target) in enumerate(subsubset_val_loader):
    target = target.cuda(non_blocking=True)
    img = img.cuda(non_blocking=True)
    with torch.no_grad():
        output1 = model1(img)
        output2 = model2(img)
        output3 = model3(img)
        preds = torch.stack([output1,output2,output3])
        _, all_preds = preds.max(-1)
        label_matrix = (all_preds == target).float().T
    targets.append(label_matrix)
    
print(len(targets), targets[0].shape, targets[-1].shape)
all_targets = torch.cat(targets)
num_correct = all_targets.sum(-1)
print(all_targets.shape, (num_correct == 3).sum(), (num_correct == 2).sum(), (num_correct == 1).sum(), (num_correct == 0).sum())

gate = models.resnet18(num_classes=3).cuda()
optimizer = torch.optim.Adam(gate.parameters(), lr=0.5, weight_decay=0.)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \
        patience=1, verbose=True, threshold=0.01, factor = 0.3)

### BCE LOSS ##

# train loop
w_acc = 0
n_acc = 0
criterion = torch.nn.BCELoss()
for ep in range(100):
    total_loss = 0.
    for it, data in enumerate(zip(subsubset_val_loader,targets)):
        img = (data[0][0]).cuda(non_blocking=True)
        label_matrix = (data[1]).cuda(non_blocking=True)
        gate_out = gate(img)
        sig_out = torch.sigmoid(gate_out)
        loss = criterion(sig_out,label_matrix)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if it % 10 == 0:
            print(f"Epoch {ep} | Step {it}/{len(subsubset_val_loader)}: {loss.item()}")
    total_loss /= len(subsubset_val_loader)
    scheduler.step(total_loss)
    print("out:", sig_out[:10])
    print("target:", label_matrix[:10])
    
    print(f"Epoch {ep} completed. Avg loss: {total_loss}")



4 torch.Size([256, 3]) torch.Size([232, 3])
torch.Size([1000, 3]) tensor(0, device='cuda:0') tensor(568, device='cuda:0') tensor(432, device='cuda:0') tensor(0, device='cuda:0')
Epoch 0 | Step 0/4: 0.7243248224258423
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 1.],
        [1., 1., 0.],
        [1., 1., 0.]], device='cuda:0')
Epoch 0 completed. Avg loss: 36.323997497558594
Epoch 1 | Step 0/4: 48.567710876464844
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
 

Epoch 13 | Step 0/4: 48.567710876464844
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 1.],
        [1., 1., 0.],
        [1., 1., 0.]], device='cuda:0')
Epoch 13 completed. Avg loss: 48.28484344482422
Epoch 14 | Step 0/4: 48.567710876464844
Epoch    15: reducing learning rate of group 0 to 1.0935e-04.
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.]

Epoch 26 | Step 0/4: 48.567710876464844
Epoch    27: reducing learning rate of group 0 to 7.9716e-08.
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 1.],
        [1., 1., 0.],
        [1., 1., 0.]], device='cuda:0')
Epoch 26 completed. Avg loss: 48.28484344482422
Epoch 27 | Step 0/4: 48.567710876464844
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.]

Epoch 40 | Step 0/4: 48.567710876464844
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 1.],
        [1., 1., 0.],
        [1., 1., 0.]], device='cuda:0')
Epoch 40 completed. Avg loss: 48.28484344482422
Epoch 41 | Step 0/4: 48.567710876464844
out: tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<SliceBackward>)
target: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0.,

KeyboardInterrupt: 

In [10]:
sd = torch.load("./dist_models/ft95perc_baseR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model1.load_state_dict(ckpt)
model1.eval()

sd = torch.load("./dist_models/ft95perc_eqR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model2.load_state_dict(ckpt)
model2.eval()

sd = torch.load("./dist_models/ft95perc_inv_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model3.load_state_dict(ckpt)
model3.eval()
model1=model1.cuda()
model2=model2.cuda()
model3=model3.cuda()


In [19]:
w_acc = 0
n_acc = 0
with torch.no_grad():
    for img, target in val_loader:
        target = target.cuda(non_blocking=True)
        output1 = model1(img.cuda(non_blocking=True))
        output2 = model2(img.cuda(non_blocking=True))
        output3 = model3(img.cuda(non_blocking=True))
        preds = torch.stack([output1,output2,output3])
#         print(preds.shape)
        _, all_preds = preds.max(-1)
#         pred_exist = (all_preds == target)
#         print(pred_exist.shape)
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
print(w_acc/len(val_dataset), n_acc/len(val_dataset))
        

tensor(236, device='cuda:0') tensor(223, device='cuda:0')
tensor(456, device='cuda:0') tensor(431, device='cuda:0')
tensor(683, device='cuda:0') tensor(652, device='cuda:0')
tensor(904, device='cuda:0') tensor(862, device='cuda:0')
tensor(1130, device='cuda:0') tensor(1076, device='cuda:0')
tensor(1357, device='cuda:0') tensor(1290, device='cuda:0')
tensor(1572, device='cuda:0') tensor(1492, device='cuda:0')
tensor(1799, device='cuda:0') tensor(1706, device='cuda:0')
tensor(2028, device='cuda:0') tensor(1923, device='cuda:0')
tensor(2246, device='cuda:0') tensor(2126, device='cuda:0')
tensor(2470, device='cuda:0') tensor(2341, device='cuda:0')
tensor(2683, device='cuda:0') tensor(2548, device='cuda:0')
tensor(2903, device='cuda:0') tensor(2752, device='cuda:0')
tensor(3119, device='cuda:0') tensor(2957, device='cuda:0')
tensor(3336, device='cuda:0') tensor(3166, device='cuda:0')
tensor(3555, device='cuda:0') tensor(3371, device='cuda:0')
tensor(3776, device='cuda:0') tensor(3578, devic

tensor(29975, device='cuda:0') tensor(28365, device='cuda:0')
tensor(30202, device='cuda:0') tensor(28576, device='cuda:0')
tensor(30431, device='cuda:0') tensor(28789, device='cuda:0')
tensor(30648, device='cuda:0') tensor(29000, device='cuda:0')
tensor(30867, device='cuda:0') tensor(29210, device='cuda:0')
tensor(31090, device='cuda:0') tensor(29421, device='cuda:0')
tensor(31314, device='cuda:0') tensor(29632, device='cuda:0')
tensor(31536, device='cuda:0') tensor(29840, device='cuda:0')
tensor(31752, device='cuda:0') tensor(30044, device='cuda:0')
tensor(31970, device='cuda:0') tensor(30253, device='cuda:0')
tensor(32197, device='cuda:0') tensor(30466, device='cuda:0')
tensor(32421, device='cuda:0') tensor(30676, device='cuda:0')
tensor(32644, device='cuda:0') tensor(30890, device='cuda:0')
tensor(32864, device='cuda:0') tensor(31097, device='cuda:0')
tensor(33079, device='cuda:0') tensor(31300, device='cuda:0')
tensor(33306, device='cuda:0') tensor(31514, device='cuda:0')
tensor(3

In [20]:
55587/len(val_dataset)

0.8671648310505132

In [21]:
52604/len(val_dataset)

0.8206296215406695

In [6]:
from utils import accuracy, AverageMeter

In [7]:
w_acc = 0
n_acc = 0
acc1 = 0
acc2 = 0
acc3 = 0
m1 = AverageMeter('m1')
m2 = AverageMeter('m2')
m3 = AverageMeter('m3')

with torch.no_grad():
    for img, target in val_loader:
        target = target.cuda(non_blocking=True)
        output1 = model1(img.cuda(non_blocking=True))
        output2 = model2(img.cuda(non_blocking=True))
        output3 = model3(img.cuda(non_blocking=True))
        preds = torch.stack([output1,output2,output3])
        _, p1 = output1.max(-1)
        _, p2 = output2.max(-1)
        _, p3 = output3.max(-1)
        
        acc1 += (p1 == target).sum()
        acc2 += (p2 == target).sum()
        acc3 += (p3 == target).sum()
        acc_1, acc_5 = accuracy(output1, target, topk=(1, 5))
        m1.update(acc_1.item(),img.size(0))
        
#         print(acc_1.item(), (p1 == target).sum().item()/len(target))
        acc_1, acc_5 = accuracy(output2, target, topk=(1, 5))
        m2.update(acc_1.item(),img.size(0))
    
#         print(acc_1.item(), (p2 == target).sum().item()/len(target))
        acc_1, acc_5 = accuracy(output3, target, topk=(1, 5))
        m3.update(acc_1.item(),img.size(0))
    
#         print(acc_1.item(), (p3 == target).sum().item()/len(target))
        _, all_preds = preds.max(-1)
#         pred_exist = (all_preds == target)
#         print(pred_exist.shape)
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
print(w_acc/len(test_dataset), n_acc/len(test_dataset))
print(acc1/len(test_dataset), acc2/len(test_dataset), acc3/len(test_dataset))
print(m1.avg,m2.avg,m3.avg)

tensor(1.1059, device='cuda:0') tensor(1.0446, device='cuda:0')
tensor(1.0170, device='cuda:0') tensor(1.0232, device='cuda:0') tensor(1.0071, device='cuda:0')
79.32981809882871 79.80718230010976 78.5560512915851


In [8]:
print(w_acc/len(val_dataset), n_acc/len(val_dataset))
print(acc1/len(val_dataset), acc2/len(val_dataset), acc3/len(val_dataset))

tensor(0.8626, device='cuda:0') tensor(0.8148, device='cuda:0')
tensor(0.7933, device='cuda:0') tensor(0.7981, device='cuda:0') tensor(0.7856, device='cuda:0')


In [10]:
val_loader.dataset.transform

Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [30]:
from networks import EnsembleSSL
from tqdm import tqdm

In [45]:
com_model = EnsembleSSL2('resnet50', 3, 1000, 'freeze').cuda()
pt_paths = ["./dist_models/ft95perc_baseR_cos_lr0.003_bs256/checkpoint_best.pth","./dist_models/ft95perc_eqR_cos_lr0.003_bs256/checkpoint_best.pth","./dist_models/ft95perc_inv_cos_lr0.003_bs256/checkpoint_best.pth"]
com_model.load_sep_weights(pt_paths)

sd = torch.load("./dist_models/ft95perc_baseR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model1.load_state_dict(ckpt)
model1.eval()

sd = torch.load("./dist_models/ft95perc_eqR_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model2.load_state_dict(ckpt)
model2.eval()

sd = torch.load("./dist_models/ft95perc_inv_cos_lr0.003_bs256/checkpoint_best.pth", map_location="cpu")
ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
model3.load_state_dict(ckpt)
model3.eval()
model1=model1.cuda()
model2=model2.cuda()
model3=model3.cuda()

initialized EnsembleSSL2
loading sep weights
missing keys []
unexpected keys []
missing keys []
unexpected keys []
missing keys []
unexpected keys []


In [57]:
w_acc = 0
n_acc = 0
com_model.eval()
with torch.no_grad():
    for i, (img, target) in enumerate(test_loader):
        target = target.cuda(non_blocking=True)
        output1 = model1(img.cuda(non_blocking=True))
        output2 = model2(img.cuda(non_blocking=True))
        output3 = model3(img.cuda(non_blocking=True))
        preds1 = torch.stack([output1,output2,output3])
        preds = com_model.forward(img.cuda(non_blocking=True))
        _, all_preds = preds.max(-1)
        _, all_preds1 = preds1.max(-1)
#         print(preds1[:,0])
#         print(preds[:,0])
        assert(torch.equal(preds1,preds))
        
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
print(w_acc/len(test_dataset), n_acc/len(test_dataset))

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
tensor(0.8437, devic

In [44]:
models_dict = {
    'resnet18': models.resnet18,
    'resnet50': models.resnet50,
}
import torch.nn as nn
from torchvision import models
from utils import consume_prefix_in_state_dict_if_present

class EnsembleSSL2(nn.Module):
    def __init__(self, arch, num_ensem=1, num_classes=1000, eval_mode='freeze'):
        super().__init__()
        print("initialized EnsembleSSL2")
        self.num_ensem = num_ensem
        self.num_classes = num_classes
        model_fn = models_dict[arch]
        self.members = torch.nn.ModuleList([model_fn(num_classes=self.num_classes) for _ in range(num_ensem)])
        self.set_eval_mode(eval_mode)

    def load_sep_weights(self, weights_path_list):
        print("loading sep weights")
        for m in range(self.num_ensem):
            weights = weights_path_list[m]
            state_dict = torch.load(weights, map_location='cpu')

            cur_mem = self.members[m]
            if 'model' in state_dict:
                if 'members.0.fc.weight' in state_dict['model']:
                    # for LP on imagenet-100 using imagenet ckpt (i.e. different num classes)
                    if state_dict['model']['members.0.fc.weight'].shape[0] != self.num_classes:
                        print(f"model weights dim: {state_dict['model']['members.0.fc.weight'].shape[0]}, num classes: {self.num_classes}")
                        new_state_dict = {k:v for k,v in state_dict['model'].items() if 'fc' not in k}
                    else:
                        new_state_dict = state_dict['model']
                    consume_prefix_in_state_dict_if_present(new_state_dict, 'members.0.') # this is assuming only 1 member was trained at a time
                    missing_keys, unexpected_keys = cur_mem.load_state_dict(new_state_dict, strict=False)
                    print('missing keys', missing_keys)
                    print('unexpected keys', unexpected_keys)

                else:
                    consume_prefix_in_state_dict_if_present(state_dict['model'], 'members.0.') # this is assuming only 1 member was trained at a time
                    consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.backbone.')
                    missing_keys, unexpected_keys = cur_mem.load_state_dict(state_dict['model'], strict=False)
                    print('missing keys', missing_keys)
                    print('unexpected keys', unexpected_keys)
                    print("===> Loaded backbone state dict from ", weights)

            elif 'backbone' in state_dict:
                missing_keys, unexpected_keys = cur_mem.load_state_dict(state_dict["backbone"], strict=False)
                print('missing keys', missing_keys)
                print('unexpected keys', unexpected_keys)
            else:
                print(state_dict.keys())

            if 'log_reg_weight' in state_dict and 'log_reg_bias' in state_dict:
                cur_mem.fc.weight = torch.from_numpy(state_dict['log_reg_weight'])
                cur_mem.fc.bias = torch.from_numpy(state_dict['log_reg_bias'])


    def load_weights(self, weights_path_list, convert_from_single=False):
        # convert_from_single: whether to convert the weights from 1 model (MultiBackbone)
        # making sure that the number of pretrained weights & ensem member size is equal
        print("loading weights")
        if not convert_from_single:
            assert len(self.members) == len(weights_path_list)
        else:
            assert len(weights_path_list) == 1
            ensem_state_dict = torch.load(weights_path_list[0], map_location='cpu')
            # ensem_state_dict = ensem_state_dict['enc_state_dict']
            state_dicts = convert_weights_from_single_backbone(ensem_state_dict, self.num_ensem)

        for m in range(self.num_ensem):
            cur_mem = self.members[m]

            if convert_from_single:
                weights = ''
                state_dict = state_dicts[m]
            else:
                weights = weights_path_list[m]
                state_dict = torch.load(weights, map_location='cpu')

            if 'epoch' in str(weights):
                consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.backbone.')
                cur_mem.load_state_dict(state_dict['model'], strict=False)
            elif 'simsiam' in str(weights):
                consume_prefix_in_state_dict_if_present(state_dict['state_dict'], 'module.backbone.')
                consume_prefix_in_state_dict_if_present(state_dict['state_dict'], 'backbone.')
                missing_keys, unexpected_keys = cur_mem.load_state_dict(state_dict["state_dict"], strict=False)
                print('missing keys', missing_keys)
                print('unexpected keys', unexpected_keys)
            else:
                if 'model' in state_dict:
                    missing_keys, unexpected_keys = cur_mem.load_state_dict(state_dict["model"], strict=False)
                elif 'backbone' in state_dict:
                    missing_keys, unexpected_keys = cur_mem.load_state_dict(state_dict["backbone"], strict=False)
                print('missing keys', missing_keys)
                print('unexpected keys', unexpected_keys)
                # assert missing_keys == ['fc.weight', 'fc.bias'] and unexpected_keys == []

            if self.eval_mode in {'linear_probe', 'finetune'}:
                cur_mem.fc.weight.data.normal_(mean=0.0, std=0.01)
                cur_mem.fc.bias.data.zero_()

    def set_eval_mode(self, mode='freeze'):
        self.eval_mode = mode
        if self.eval_mode == 'freeze':
            for cur_mem in self.members:
                cur_mem.requires_grad_(False)
                cur_mem.fc.requires_grad_(False)
        elif self.eval_mode == 'linear_probe':
            for cur_mem in self.members:
                cur_mem.requires_grad_(False)
                cur_mem.fc.requires_grad_(True)
        elif self.eval_mode == 'finetune':
            for cur_mem in self.members:
                cur_mem.requires_grad_(True)
                cur_mem.fc.requires_grad_(True)
        elif self.eval_mode == 'log_reg':
            for cur_mem in self.members:
                cur_mem.fc = nn.Identity()
                cur_mem.requires_grad_(False)
                cur_mem.fc.requires_grad_(False)
        elif self.eval_mode == 'extract_features':
            for cur_mem in self.members:
                # replacing all members' fc layers with identity to extract features
                cur_mem.fc = nn.Identity()

                cur_mem.requires_grad_(False)
                cur_mem.fc.requires_grad_(False)
        else:
            raise NotImplementedError(f'Evaluation mode {mode} not implemented')


    # forward that feeds inputs to the ensemble and returns stacked outputs
    def forward(self, x, gate_cond=None):
        outputs = []

        for cur_mem in self.members:
            output = cur_mem(x)
            outputs.append(output)

        return torch.stack(outputs)

    # forward that feeds inputs to the ensemble and returns only the averaged outputs from the ensemble
    def forward_ensem(self, x):
        outputs = self.forward(x)

        return outputs.softmax(dim=-1).mean(dim=0)
