-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
69 lines (56 loc) · 1.82 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import print_function
import torch
import numpy as np
from torchvision import transforms
import torchvision
import os
import pickle
import matplotlib.pyplot as plt
from PIL import Image
import importlib
from DeepDIGCode import config
# Project-wide settings object exposed by DeepDIGCode.config (presumably a parsed
# argparse namespace -- TODO confirm); in the code visible here only
# `project_dir` is read from it.
args= config.args
# Folder expected to contain the MNIST `training.pt` / `test.pt` files loaded by
# the functions below. Built by plain string concatenation, so it always ends
# with '/' and file names can be appended directly.
data_dir = args.project_dir+'/Data/MNIST/'
# Shared PIL-image -> tensor conversion applied to every loaded sample.
transform = transforms.ToTensor()
def get_original_data(split='train'):
    """Load every image of one MNIST split as a single stacked tensor.

    Args:
        split: ``'train'`` (reads ``training.pt``) or ``'test'`` (reads
            ``test.pt``) from ``data_dir``.

    Returns:
        ``(data, labels)`` where ``data`` is ``torch.stack`` of ``transform``
        applied to each image (converted to a mode-``'L'`` PIL image first),
        and ``labels`` is the label object exactly as stored in the file.
        With the module's default ``ToTensor`` transform, ``data`` is
        presumably ``(N, 1, H, W)`` floats in ``[0, 1]`` -- TODO confirm.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'test'``.
    """
    file_names = {'train': 'training.pt', 'test': 'test.pt'}
    # Validate up front: previously an unknown split fell through both
    # branches and crashed later with an UnboundLocalError.
    if split not in file_names:
        raise ValueError(
            "split must be 'train' or 'test', got {!r}".format(split))
    # NOTE(security): torch.load unpickles the file; only point data_dir at
    # trusted data.
    _data, labels = torch.load(data_dir + file_names[split])
    # Route each raw image through PIL so `transform` receives the input type
    # it is written for.
    data = torch.stack(
        [transform(Image.fromarray(img.numpy(), mode='L')) for img in _data])
    return data, labels
def get_class_specific_data(Class, split='train'):
    """Load only the MNIST images whose label equals ``Class``.

    Args:
        Class: Label value to keep (compared with ``==`` against the stored
            labels).
        split: ``'train'`` (reads ``training.pt``) or ``'test'`` (reads
            ``test.pt``) from ``data_dir``.

    Returns:
        ``(data, labels)``: ``data`` is ``torch.stack`` of ``transform``
        applied to each selected image (converted to a mode-``'L'`` PIL image
        first); ``labels`` is a 1-D tensor of the matching labels.

    Raises:
        ValueError: if ``split`` is unknown, or no sample carries ``Class``.
    """
    file_names = {'train': 'training.pt', 'test': 'test.pt'}
    if split not in file_names:
        raise ValueError(
            "split must be 'train' or 'test', got {!r}".format(split))
    # NOTE(security): torch.load unpickles the file; only point data_dir at
    # trusted data.
    _data, _labels = torch.load(data_dir + file_names[split])
    # One vectorised comparison replaces a Python-level tensor comparison per
    # sample. Assumes the stored labels form a 1-D tensor (the original code
    # stacked its elements, which requires tensor elements) -- TODO confirm.
    mask = _labels == Class
    labels = _labels[mask]
    if labels.numel() == 0:
        # Previously surfaced as an opaque torch.stack([]) RuntimeError.
        raise ValueError(
            'no samples with label {!r} in the {} split'.format(Class, split))
    data = torch.stack(
        [transform(Image.fromarray(sample.numpy(), mode='L'))
         for sample in _data[mask]])
    return data, labels
def imshow(img, fname, show=True, title=""):
    """Draw a batch of images as one grid, then display it or save it.

    Args:
        img: Image batch accepted by ``torchvision.utils.make_grid``.
        fname: Output path used by ``plt.savefig`` when ``show`` is False;
            ignored when ``show`` is True.
        show: If True, call ``plt.show()``; otherwise save to ``fname``.
        title: Title drawn above the grid.
    """
    grid = torchvision.utils.make_grid(img.detach())
    npimg = grid.cpu().numpy()
    plt.axis('off')
    # make_grid yields channels-first (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    if show:
        plt.show()
    else:
        plt.savefig(fname)
        # Release the figure: otherwise each call stacks another image on the
        # same axes, growing memory and making every savefig re-render all
        # previously drawn grids (e.g. across the save_samples loop).
        plt.close()
def save_samples(dir, samples, filename, show=False, per_file=50):
    """Render ``samples`` as a series of image grids via ``imshow``.

    Args:
        dir: Output directory; created (including parents) if missing.
            (The name shadows the builtin but is kept for keyword callers.)
        samples: Sliceable batch passed chunk-wise to ``imshow``.
        filename: Base name; chunk ``j`` is written as ``<filename>_<j>.png``
            inside ``dir``.
        show: Forwarded to ``imshow``; when True the grids are displayed
            instead of saved.
        per_file: Maximum number of samples per grid (default 50, the
            previously hard-coded value).

    Raises:
        ValueError: if ``per_file`` is smaller than 1.
    """
    if per_file < 1:
        raise ValueError('per_file must be >= 1, got {!r}'.format(per_file))
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(dir, exist_ok=True)
    for index, start in enumerate(range(0, len(samples), per_file)):
        # os.path.join keeps the file inside `dir` even without a trailing
        # separator; for dirs already ending in '/' the path is unchanged.
        target = os.path.join(dir, '{}_{}.png'.format(filename, index))
        # Slicing clamps at the end, so the last chunk may be shorter.
        imshow(samples[start:start + per_file], target, show=show)