In [1]:
import numpy as np
import os 
from os import walk

def load_safari(folder, data_dir="./data", total_samples=80000):
    """Load image arrays saved with ``np.save`` from ``<data_dir>/<folder>``.

    The folder tree is walked and, from each directory, the first file whose name
    is not ``.DS_Store`` is loaded. Each file's rows are reshaped to 28x28x1
    images (so rows are presumably 784 pixel values in 0..255 — confirm), a random
    subset of ``total_samples // n_files`` rows is kept, and pixels are mapped
    from 0..255 to -1..1. The subset is drawn from NumPy's global RNG, so calling
    ``np.random.seed`` first makes it reproducible; the global RNG is not reseeded.

    Args:
        folder: sub-directory of ``data_dir`` to load.
        data_dir: root data directory (default ``"./data"``).
        total_samples: total number of images to aim for, split evenly across
            the loaded files (default 80000).

    Returns:
        (x, y): ``x`` is a float32 array of shape ``(N, 28, 28, 1)``; ``y`` is an
        integer ndarray of length ``N`` holding each image's file index.

    Raises:
        FileNotFoundError: if no loadable file is found under the folder.
    """
    mypath = os.path.join(data_dir, folder)
    file_paths = []
    for dirpath, _dirnames, filenames in walk(mypath):
        for f in filenames:
            if f != '.DS_Store':
                # Join with dirpath (not the root) so files in subdirectories resolve.
                # NOTE(review): only the first usable file per directory is taken, and
                # listing order is filesystem-dependent — confirm this is intended.
                file_paths.append(os.path.join(dirpath, f))
                break

    if not file_paths:
        raise FileNotFoundError(f"no data files found under {mypath!r}")

    slice_train = total_samples // len(file_paths)

    xs, ys = [], []
    for label, path in enumerate(file_paths):
        raw = np.load(path)
        # Choose the random subset first so only the kept rows are converted to float32.
        keep = np.random.permutation(len(raw))[:slice_train]
        x = (raw[keep].astype('float32') - 127.5) / 127.5
        xs.append(x.reshape(x.shape[0], 28, 28, 1))
        ys.append(np.full(len(x), label))

    # Later files come first, preserving the original prepend-style ordering.
    return np.concatenate(xs[::-1], axis=0), np.concatenate(ys[::-1], axis=0)

In [2]:
import torch
import torch.nn as nn

class View(nn.Module):
    """Reshape the input with ``Tensor.view`` so it can sit inside ``nn.Sequential``.

    Args:
        shape: target shape passed to ``Tensor.view``, e.g. ``(-1, 64, 7, 7)``;
            a single ``int`` is also accepted.
    """

    def __init__(self, shape):
        super().__init__()
        # Store a flat tuple so ``self.shape`` really is the target shape
        # (a trailing comma here would wrap it in an extra 1-tuple).
        self.shape = (shape,) if isinstance(shape, int) else tuple(shape)

    def forward(self, x):
        return x.view(*self.shape)

    def extra_repr(self):
        return f"shape={self.shape}"

In [7]:
class D(nn.Module):
    """WGAN critic: maps a (N, 1, 28, 28) image batch to one unbounded score per image.

    Args:
        dim_latent: stored as ``self.dim_latent``; not used by the critic itself.
    """

    def __init__(self, dim_latent=200):
        # Initialise the PyTorch parent class.
        super().__init__()
        self.dim_latent = dim_latent
        # Network layers; the comments give the activation shape after each stage.
        self.model = nn.Sequential(
            # (-1, 1, 28, 28)
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 64, 14, 14)

            nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 64, 7, 7)

            nn.Conv2d(64, 128, kernel_size=4),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 128, 4, 4)

            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 128, 4, 4)

            View((-1, 128 * 4 * 4)),
            nn.Linear(128 * 4 * 4, 1),
        )

        # Loss function: mean(y_true * y_pred), see ``wasserstein``.
        self.loss_function = self.wasserstein
        # Optimizer over all critic parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)

    def forward(self, x):
        return self.model(x)

    def wasserstein(self, y_pred, y_true):
        """Return ``mean(y_true * y_pred)``.

        Minimising it with a +1 target pushes scores down, and with a -1 target
        pushes scores up.
        """
        return torch.mean(y_true * y_pred)

    def train(self, x_batch=True, label=None):
        """Run one critic optimisation step, or set train/eval mode.

        ``d.train(x_batch, label)``: zero the gradients, score ``x_batch``, compute
        ``loss_function(scores, label)``, backpropagate, step the optimizer and
        return the loss tensor.

        ``d.train()`` / ``d.train(False)``: behave like ``nn.Module.train`` and
        return ``self``. ``nn.Module.eval()`` calls ``self.train(False)``, so this
        form must be supported or ``d.eval()`` fails.
        """
        if label is None and isinstance(x_batch, bool):
            return super().train(x_batch)

        self.optimizer.zero_grad()
        y_pred = self.forward(x_batch)

        loss = self.loss_function(y_pred, label)
        loss.backward()
        self.optimizer.step()
        # Any weight clipping must be applied by the caller after this step, under
        # no_grad: the clamp is then never differentiated, and the next step's
        # gradients are simply taken at the clipped weights.
        return loss

In [8]:
class G(nn.Module):
    """Generator: maps a (N, dim_latent) noise batch to (N, 1, 28, 28) images in [-1, 1].

    Args:
        dim_latent: size of the input noise vector (default 100).
    """

    def __init__(self, dim_latent=100):
        # Initialise the PyTorch parent class.
        super().__init__()
        # Network layers; the comments give the activation shape after each stage.
        self.dim_latent = dim_latent
        self.model = nn.Sequential(
            # (-1, dim_latent)
            nn.Linear(self.dim_latent, 64 * 7 * 7),
            # (-1, 3136)
            View((-1, 64, 7, 7)),
            # (-1, 64, 7, 7)

            nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 64, 14, 14)

            nn.ConvTranspose2d(64, 128, 5, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 128, 14, 14)

            nn.ConvTranspose2d(128, 64, 6, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            # (-1, 64, 28, 28)

            nn.ConvTranspose2d(64, 1, 5, padding=2),
            nn.Tanh(),
            # (-1, 1, 28, 28)
        )

        # Optimizer over the generator's parameters only.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, x):
        return self.model(x)

    def train(self, D=True, latent_batch=None, label=None):
        """Run one generator optimisation step, or set train/eval mode.

        ``g.train(D, latent_batch, label)``: generate images from ``latent_batch``,
        score them with the critic ``D``, compute ``D.loss_function(scores, label)``,
        backpropagate through ``D`` into the generator and step only the
        generator's optimizer. Returns the loss tensor.

        ``g.train()`` / ``g.train(False)``: behave like ``nn.Module.train`` and
        return ``self``. ``nn.Module.eval()`` calls ``self.train(False)``, so this
        form must be supported or ``g.eval()`` fails.
        """
        if latent_batch is None and label is None and isinstance(D, bool):
            return super().train(D)

        self.optimizer.zero_grad()
        # The critic's parameters also receive gradients here but are not stepped;
        # no D-side zero_grad is needed as long as D's own update zeroes its
        # gradients before backpropagating (D.train does).
        g_output = self.forward(latent_batch)
        d_output = D(g_output)

        loss = D.loss_function(d_output, label)
        loss.backward()
        self.optimizer.step()
        # Training pattern used with this pair:
        #   D(x_input, +1)         -> critic step on real data
        #   D(G(latent), -1)       -> critic step on generated data
        #   D(G(latent), +1)       -> generator step (this method)
        return loss


In [9]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class MyDataset(Dataset):
    """Map-style dataset yielding channels-first float tensors.

    ``x_train`` / ``y_train`` come from ``load_safari(foldername)``. ``x_train`` is
    kept in its loaded (channels-last) layout and axis 3 is moved to axis 1 on
    access. ``y_train`` is stored but not returned by ``__getitem__``.
    """

    def __init__(self, foldername='camel'):
        self.x_train, self.y_train = load_safari(foldername)

    def __len__(self):
        return self.x_train.shape[0]

    def __getitem__(self, idx):
        # (N, H, W, C) -> (N, C, H, W) as a view, then index and copy into a FloatTensor.
        channels_first = np.moveaxis(self.x_train, 3, 1)
        return torch.FloatTensor(channels_first[idx])
        #return torch.cuda.FloatTensor(self.x_train[idx])
# Shuffled mini-batches of real images for the critic (CPU). A CUDA-default-device
# setup would additionally pass generator=torch.Generator(device='cuda').
data_loader = DataLoader(dataset=MyDataset(), shuffle=True, batch_size=64)


In [None]:
# WGAN training: N_CRITIC critic updates (each followed by weight clipping)
# per generator update.
N_EPOCHS = 100
N_CRITIC = 5        # critic steps per generator step
CLIP_VALUE = 0.01   # critic parameters are clamped to [-CLIP_VALUE, CLIP_VALUE]

d = D()
g = G()
# TODO: move d, g and each batch to a device for GPU training.

d_loss_hist = []
g_loss_hist = []
g_loss = None  # stays None until the first generator step; see the epoch log below

for epoch in range(1, N_EPOCHS + 1):
    for i, x_input in enumerate(data_loader):
        # Size every tensor from the actual batch: the final batch can be smaller
        # than the loader's batch_size, and a fixed 64 would then not broadcast
        # against the labels in the loss.
        batch_size = x_input.shape[0]
        label = torch.ones((batch_size, 1))        # target for real images
        fake_label = -torch.ones((batch_size, 1))  # target for generated images

        d_loss = d.train(x_input, label)
        fake_batch = g(torch.randn(batch_size, g.dim_latent)).detach()
        d_fake_loss = d.train(fake_batch, fake_label)

        # Weight clipping (original WGAN) is done outside autograd.
        with torch.no_grad():
            for param in d.parameters():
                param.clamp_(-CLIP_VALUE, CLIP_VALUE)

        total_d_loss = (d_loss.item() + d_fake_loss.item()) / 2
        d_loss_hist.append(total_d_loss)

        if i % N_CRITIC == N_CRITIC - 1:
            # The generator is trained towards the critic's "real" target.
            g_loss = g.train(d, torch.randn(batch_size, g.dim_latent), label)
            g_loss_hist.append(g_loss.item())

    # One summary line per epoch; avoids a NameError if no generator step ran yet.
    g_loss_text = "n/a" if g_loss is None else f"{g_loss.item():.9f}"
    print(f"Epoch {epoch}: d_loss = {total_d_loss:.9f}  g loss={g_loss_text}")


In [None]:
import matplotlib.pyplot as plt

# Sample 8 images from the generator and tile them side by side.
# Put G's layers in inference mode (Dropout off) while sampling, then restore.
was_training = g.model.training
g.model.eval()
with torch.no_grad():
    fake_pred = g(torch.randn(8, g.dim_latent)).cpu().numpy()  # (8, 1, 28, 28)
g.model.train(was_training)

fake_pred = fake_pred.transpose(0, 2, 3, 1)                    # (8, 28, 28, 1)
fake_8img = fake_pred.transpose(1, 0, 2, 3).reshape(28, -1)    # (28, 8 * 28)
plt.figure(figsize=(16, 2))
# Tanh output lies in [-1, 1]; fix the grayscale range so images are comparable.
plt.imshow(fake_8img, cmap='gray', vmin=-1, vmax=1)
plt.title('generator samples')
plt.axis('off')
plt.show()

In [None]:
def visualize(loss_hist, title='generator loss'):
    """Scatter-plot a loss history, one point per recorded step.

    Args:
        loss_hist: sequence of scalar losses.
        title: plot title (default keeps the original 'generator loss').
    """
    plt.figure(figsize=(12, 4))
    steps = np.arange(1, len(loss_hist) + 1)
    plt.scatter(steps, loss_hist, s=0.5)
    plt.title(title)
    plt.xlabel('recorded step')
    plt.ylabel('loss')
    plt.show()


# The generator loss is recorded once per generator step, the critic loss once per batch.
visualize(g_loss_hist, 'generator loss')
visualize(d_loss_hist, 'critic loss')

AssertionError: 