In [None]:
import torch
import torchvision 
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models.mobilenet import mobilenet_v2
from torch.nn import CrossEntropyLoss
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Pick the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
  device = 'mps'
elif torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# MNIST train and test splits, both stored under 'mnist/' and converted to tensors.
# Only the training split requests a download.
training_dataset = datasets.MNIST(
    root='mnist/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='mnist/',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Merge both splits into a single dataset (train samples first, then test samples).
full_dataset = torch.utils.data.ConcatDataset([training_dataset, test_dataset])
type(full_dataset)

torch.utils.data.dataset.ConcatDataset

In [None]:
# Shuffled loader over the merged dataset, 1000 samples per batch.
mnist_dataloader = DataLoader(
    dataset=full_dataset,
    batch_size=1000,
    shuffle=True,
)
type(mnist_dataloader)

torch.utils.data.dataloader.DataLoader

In [75]:
class TopKDiscriminator:
  """Trains a classifier and flags samples whose label is outside its top-k predictions.

  Args:
    model: torch module called on a batch of data; its output is scored along dim 1.
    device: device (string or torch.device) used for the model and every batch.
    dataloader: loader yielding (data, label) batches; consumed by `fit`.
    c_in_model: number of input channels the model expects.
    c_in_data: number of channels of the data yielded by the loaders.
    batch_size: kept for backward compatibility. `predict` derives sample
      positions from the batches it actually receives and does not read it.
  """

  def __init__(self, model, device, dataloader, c_in_model, c_in_data, batch_size):
    self.model = model
    self.device = device
    self.dataloader = dataloader
    self.c_in_model = c_in_model
    self.c_in_data = c_in_data
    self.batch_size = batch_size

  def _match_channels(self, data):
    """Tile `data` along dim 1 when the data and model channel counts differ."""
    if self.c_in_model != self.c_in_data:
      # repeat(1, C, 1, 1) copies dim 1 C times, so the result has exactly
      # c_in_model channels only when the input is single-channel.
      data = data.repeat(1, self.c_in_model, 1, 1)
    return data

  def fit(self, loss_fn=CrossEntropyLoss(), learning_rate=0.01, epochs=10):
    """Train `self.model` on `self.dataloader` with Adam and print per-epoch accuracy.

    Args:
      loss_fn: callable taking (model_output, label) and returning a scalar loss.
      learning_rate: learning rate handed to the Adam optimizer.
      epochs: number of full passes over `self.dataloader`.
    """
    # Use the device given to the constructor, not a notebook-level global.
    self.model.to(self.device)
    self.model.train()
    opt = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    for epoch in tqdm(range(epochs)):
      correct = 0
      for data, label in self.dataloader:
        data = self._match_channels(data)
        data, label = data.to(self.device), label.to(self.device)

        opt.zero_grad()
        y_hat = self.model(data)
        loss = loss_fn(y_hat, label)
        loss.backward()
        opt.step()

        # Running count of samples whose arg-max class equals the label.
        pred = y_hat.argmax(dim=1, keepdim=True)
        results = label.eq(pred.view_as(label))
        correct += results.sum().item()

      acc = correct / len(self.dataloader.dataset)
      print("\n Accuracy this epoch = {}".format(acc))

  def predict(self, testloader, k=1):
    """Evaluate top-k accuracy on `testloader` and report the misclassified samples.

    Args:
      testloader: loader yielding (data, label) batches.
      k: a sample counts as correct when its label is among the k
        highest-scoring classes.

    Returns:
      List of int positions, in iteration order over `testloader`, of the
      samples whose label is not in the top-k. These equal dataset indices
      only when the loader iterates the dataset in order (shuffle=False).
    """
    # Make sure model and batches share a device even if `fit` was never called.
    self.model.to(self.device)
    self.model.eval()
    correct = 0
    error_idx = []
    seen = 0  # number of samples consumed so far == position of the next batch's first sample

    with torch.no_grad():
      for data, label in testloader:
        data = self._match_channels(data)
        data, label = data.to(self.device), label.to(self.device)
        y_hat = self.model(data)

        # (batch, k) top-k class ids compared against each label, reduced over k.
        topk_classes = y_hat.topk(k=k).indices
        hits = (topk_classes == label.reshape(-1, 1)).any(dim=1)

        # Offset by the running sample count rather than batch_idx * batch_size,
        # so the positions stay right for any loader batch size or a short last batch.
        misses = (~hits).nonzero().flatten()
        error_idx += (misses + seen).tolist()
        correct += hits.sum().item()
        seen += hits.numel()

    acc = correct / len(testloader.dataset)
    print("\n Accuracy = {}/{} , {}".format(correct,len(testloader.dataset),acc))

    return error_idx

In [76]:
# `weights=` supersedes the deprecated `pretrained=True` flag (torchvision >= 0.13);
# "IMAGENET1K_V1" is the checkpoint that `pretrained=True` resolved to.
model = mobilenet_v2(weights="IMAGENET1K_V1")
# No parameters are frozen, so the whole network is fine-tuned. Only the last
# classifier layer is swapped for a fresh 10-output linear layer.
model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=10)



In [77]:
# The model takes 3-channel input while the data is declared as 1-channel,
# so the discriminator tiles each batch along the channel axis.
discriminator = TopKDiscriminator(
    model=model,
    device=device,
    dataloader=mnist_dataloader,
    c_in_model=3,
    c_in_data=1,
    batch_size=1000,
)

In [78]:
# Train for 20 epochs with cross-entropy loss and a learning rate of 1e-3.
discriminator.fit(
    loss_fn=CrossEntropyLoss(),
    learning_rate=0.001,
    epochs=20,
)

  5%|▌         | 1/20 [00:17<05:33, 17.56s/it]


 Accuracy this epoch = 0.9455714285714286


 10%|█         | 2/20 [00:34<05:08, 17.16s/it]


 Accuracy this epoch = 0.9908571428571429


 15%|█▌        | 3/20 [00:52<04:55, 17.41s/it]


 Accuracy this epoch = 0.9928142857142858


 20%|██        | 4/20 [01:09<04:41, 17.58s/it]


 Accuracy this epoch = 0.9943285714285715


 25%|██▌       | 5/20 [01:26<04:18, 17.25s/it]


 Accuracy this epoch = 0.9957285714285714


 30%|███       | 6/20 [01:43<04:00, 17.20s/it]


 Accuracy this epoch = 0.9956857142857143


 35%|███▌      | 7/20 [02:00<03:41, 17.07s/it]


 Accuracy this epoch = 0.9967


 40%|████      | 8/20 [02:17<03:23, 16.98s/it]


 Accuracy this epoch = 0.9969857142857143


 45%|████▌     | 9/20 [02:34<03:07, 17.07s/it]


 Accuracy this epoch = 0.9974571428571428


 50%|█████     | 10/20 [02:51<02:49, 16.98s/it]


 Accuracy this epoch = 0.9973285714285715


 55%|█████▌    | 11/20 [03:08<02:33, 17.07s/it]


 Accuracy this epoch = 0.9978428571428571


 60%|██████    | 12/20 [03:25<02:16, 17.01s/it]


 Accuracy this epoch = 0.9975428571428572


 65%|██████▌   | 13/20 [03:42<01:58, 17.00s/it]


 Accuracy this epoch = 0.9975857142857143


 70%|███████   | 14/20 [03:59<01:42, 17.03s/it]


 Accuracy this epoch = 0.9977285714285714


 75%|███████▌  | 15/20 [04:16<01:25, 17.01s/it]


 Accuracy this epoch = 0.9976714285714285


 80%|████████  | 16/20 [04:33<01:08, 17.04s/it]


 Accuracy this epoch = 0.9974428571428572


 85%|████████▌ | 17/20 [04:50<00:50, 16.95s/it]


 Accuracy this epoch = 0.9979


 90%|█████████ | 18/20 [05:07<00:33, 16.95s/it]


 Accuracy this epoch = 0.9983


 95%|█████████▌| 19/20 [05:25<00:17, 17.30s/it]


 Accuracy this epoch = 0.9989428571428571


100%|██████████| 20/20 [05:42<00:00, 17.14s/it]


 Accuracy this epoch = 0.9984142857142857





In [80]:
# Persist the model's parameters (state dict only, not the module object).
torch.save(obj=model.state_dict(), f='mobilenet.pt')

In [79]:
# NOTE(review): torch.load unpickles the file; only load it from a trusted source.
ddpm_dataset = torch.load('tests_ddpm_with_classes.pt')
# shuffle=False keeps batches in dataset order, which the index-based filtering
# of `ddpm_dataset` elsewhere in this notebook depends on.
ddpm_dataloader = DataLoader(
    dataset=ddpm_dataset,
    batch_size=1000,
    shuffle=False,
)

In [83]:
# NOTE(review): torch.load unpickles the file; only load it from a trusted source.
ddim_dataset = torch.load('tests_ddim_with_classes.pt')
# shuffle=False keeps batches in dataset order, which the index-based filtering
# of `ddim_dataset` elsewhere in this notebook depends on.
ddim_dataloader = DataLoader(
    dataset=ddim_dataset,
    batch_size=1000,
    shuffle=False,
)

In [81]:
# Positions of DDPM samples whose label is not the classifier's top-1 class.
ddpm_errors = discriminator.predict(testloader=ddpm_dataloader, k=1)


 Accuracy = 4920/5040 , 0.9761904761904762


In [84]:
# Positions of DDIM samples whose label is not the classifier's top-1 class.
ddim_errors = discriminator.predict(testloader=ddim_dataloader, k=1)


 Accuracy = 4761/5040 , 0.9446428571428571


In [85]:
# Keep only the samples the classifier accepted at k=1. Converting the error
# lists to sets makes each membership test O(1) instead of a scan of the list.
ddpm_error_set_k_1 = set(ddpm_errors)
ddim_error_set_k_1 = set(ddim_errors)
ddpm_corrected_k_1 = [val for idx, val in enumerate(ddpm_dataset) if idx not in ddpm_error_set_k_1]
ddim_corrected_k_1 = [val for idx, val in enumerate(ddim_dataset) if idx not in ddim_error_set_k_1]

In [86]:
# Write the k=1 filtered sample lists to disk.
torch.save(obj=ddpm_corrected_k_1, f='ddpm_corrected_k_1.pt')
torch.save(obj=ddim_corrected_k_1, f='ddim_corrected_k_1.pt')

In [87]:
# Colab-only helper: trigger browser downloads of the k=1 filtered files.
from google.colab import files
files.download(filename='ddpm_corrected_k_1.pt')
files.download(filename='ddim_corrected_k_1.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [88]:
# Same check with k=2: a sample passes when its label is among the two top-scoring classes.
ddpm_errors_k_2 = discriminator.predict(testloader=ddpm_dataloader, k=2)
ddim_errors_k_2 = discriminator.predict(testloader=ddim_dataloader, k=2)


 Accuracy = 4998/5040 , 0.9916666666666667

 Accuracy = 4898/5040 , 0.9718253968253968


In [89]:
# Keep only the samples the classifier accepted at k=2. Converting the error
# lists to sets makes each membership test O(1) instead of a scan of the list.
ddpm_error_set_k_2 = set(ddpm_errors_k_2)
ddim_error_set_k_2 = set(ddim_errors_k_2)
ddpm_corrected_k_2 = [val for idx, val in enumerate(ddpm_dataset) if idx not in ddpm_error_set_k_2]
ddim_corrected_k_2 = [val for idx, val in enumerate(ddim_dataset) if idx not in ddim_error_set_k_2]

In [None]:
# Write the k=2 filtered sample lists. The previous version saved the k=1
# lists under the k=2 file names.
torch.save(ddpm_corrected_k_2, 'ddpm_corrected_k_2.pt')
torch.save(ddim_corrected_k_2, 'ddim_corrected_k_2.pt')

In [None]:
# Trigger browser downloads of the k=2 filtered files; relies on `files` from
# google.colab having been imported in an earlier cell.
files.download(filename='ddpm_corrected_k_2.pt')
files.download(filename='ddim_corrected_k_2.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>