Skip to content

Commit

Permalink
Refactored model to add CNN-MLP with same input and output stages
Browse files Browse the repository at this point in the history
  • Loading branch information
mdda committed Jul 5, 2017
1 parent 3cbc9aa commit c49490b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 52 deletions.
13 changes: 10 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
import torch
from torch.autograd import Variable

import model
from model import RN, CNN_MLP


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Build the command-line parser; --model selects which network class is instantiated.
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
# The help text previously read "resume from model stored", which describes a
# different option; it now documents what --model actually controls.
parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN',
                    help='model architecture to train: RN or CNN_MLP (default: RN)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
Expand All @@ -40,7 +43,11 @@
if args.cuda:
torch.cuda.manual_seed(args.seed)

model = model.RN(args)
# Choose the network class from --model; any value other than 'CNN_MLP'
# falls through to the relational network.
model = (CNN_MLP if args.model == 'CNN_MLP' else RN)(args)

model_dirs = './model'
bs = args.batch_size
input_img = torch.FloatTensor(bs, 3, 75, 75)
Expand Down
153 changes: 104 additions & 49 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import torch.optim as optim
from torch.autograd import Variable

class RN(nn.Module):

def __init__(self,args):
super(RN, self).__init__()
class ConvInputModel(nn.Module):
def __init__(self):
super(ConvInputModel, self).__init__()

self.conv1 = nn.Conv2d(3, 24, 3, stride=2, padding=1)
self.batchNorm1 = nn.BatchNorm2d(24)
self.conv2 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
Expand All @@ -17,6 +18,73 @@ def __init__(self,args):
self.batchNorm3 = nn.BatchNorm2d(24)
self.conv4 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
self.batchNorm4 = nn.BatchNorm2d(24)


def forward(self, img, qst):
    """Pass ``img`` through four conv -> ReLU -> batch-norm stages.

    ``qst`` is accepted but is not read by this stage.
    """
    stages = (
        (self.conv1, self.batchNorm1),
        (self.conv2, self.batchNorm2),
        (self.conv3, self.batchNorm3),
        (self.conv4, self.batchNorm4),
    )
    features = img
    for conv, batch_norm in stages:
        # The activation is applied before normalisation in every stage.
        features = batch_norm(F.relu(conv(features)))
    return features


class FCOutputModel(nn.Module):
    """Fully connected output head: 256 input features -> 10 log-softmax outputs."""

    def __init__(self):
        super(FCOutputModel, self).__init__()

        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Apply fc2 -> ReLU -> dropout -> fc3 and return the log-softmax of the result."""
        x = self.fc2(x)
        x = F.relu(x)
        # The functional dropout defaults to training=True, so without the
        # explicit flag units would still be dropped after model.eval().
        x = F.dropout(x, training=self.training)
        x = self.fc3(x)
        return F.log_softmax(x)



class BasicModel(nn.Module):
    """Common train / evaluate / checkpoint helpers shared by the models.

    This class does not define ``forward`` or create ``self.optimizer``;
    ``train_`` and ``test_`` rely on both being provided by the subclass.
    """

    def __init__(self, args, name):
        super(BasicModel, self).__init__()
        # ``args`` is not read here; ``name`` is used in the checkpoint file name.
        self.name = name

    @staticmethod
    def _percent_correct(output, label):
        """Return the percentage of rows whose arg-max over dim 1 equals ``label``."""
        _, predicted = output.data.max(1)
        hits = predicted.eq(label.data).cpu().sum()
        return hits * 100. / len(label)

    def train_(self, input_img, input_qst, label):
        """Run one optimiser step on the batch and return its accuracy in percent."""
        self.optimizer.zero_grad()
        output = self(input_img, input_qst)
        F.nll_loss(output, label).backward()
        self.optimizer.step()
        return self._percent_correct(output, label)

    def test_(self, input_img, input_qst, label):
        """Return the batch accuracy in percent without updating any weights."""
        return self._percent_correct(self(input_img, input_qst), label)

    def save_model(self, epoch):
        """Save ``state_dict`` to ``model/epoch_<name>_<epoch>.pth``."""
        checkpoint_path = 'model/epoch_{}_{}.pth'.format(self.name, epoch)
        torch.save(self.state_dict(), checkpoint_path)


class RN(BasicModel):
def __init__(self, args):
super(RN, self).__init__(args, 'RN')

self.conv = ConvInputModel()

##(number of filters per object+coordinate of object)*2+question vector
self.g_fc1 = nn.Linear((24+2)*2+11, 256)
Expand All @@ -26,10 +94,6 @@ def __init__(self,args):
self.g_fc4 = nn.Linear(256, 256)

self.f_fc1 = nn.Linear(256, 256)
self.f_fc2 = nn.Linear(256, 256)
self.f_fc3 = nn.Linear(256, 10)

self.optimizer = optim.Adam(self.parameters(), lr=args.lr)

self.coord_oi = torch.FloatTensor(args.batch_size, 2)
self.coord_oj = torch.FloatTensor(args.batch_size, 2)
Expand All @@ -40,54 +104,50 @@ def __init__(self,args):
self.coord_oj = Variable(self.coord_oj)

# prepare coord tensor
def cvt_coord(i):
    """Map flat index ``i`` (0..24) of the 5x5 grid to a two-element coordinate list.

    The first entry comes from ``i // 5`` and the second from ``i % 5``; each
    is shifted and scaled so that indices 0..4 map to -1.0 .. 1.0.
    """
    # Floor division keeps the first index integral on Python 3, where ``/``
    # is true division and would produce off-grid values such as -0.3.
    return [(i // 5 - 2) / 2., (i % 5 - 2) / 2.]

self.coord_tensor = torch.FloatTensor(args.batch_size, 25, 2)
if args.cuda:
self.coord_tensor = self.coord_tensor.cuda()
self.coord_tensor = Variable(self.coord_tensor)
np_coord_tensor = np.zeros((args.batch_size, 25, 2))
for i in range(25):
np_coord_tensor[:,i,:] = np.array(self.cvt_coord(i))
np_coord_tensor[:,i,:] = np.array( cvt_coord(i) )
self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor))

def cvt_coord(self, i):
    """Map flat index ``i`` (0..24) of the 5x5 grid to a two-element coordinate list.

    The first entry comes from ``i // 5`` and the second from ``i % 5``; each
    is shifted and scaled so that indices 0..4 map to -1.0 .. 1.0.
    """
    # Floor division keeps the first index integral on Python 3, where ``/``
    # is true division and would produce off-grid values such as -0.3.
    return [(i // 5 - 2) / 2., (i % 5 - 2) / 2.]

self.fcout = FCOutputModel()

self.optimizer = optim.Adam(self.parameters(), lr=args.lr)

def forward(self, img, qst):
"""convolution"""
x = self.conv1(img)
x = F.relu(x)
x = self.batchNorm1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.batchNorm2(x)
x = self.conv3(x)
x = F.relu(x)
x = self.batchNorm3(x)
x = self.conv4(x)
x = F.relu(x)
x = self.batchNorm4(x)
## x = (64 x 24 x 5 x 5)
x = self.conv(img, qst) ## x = (64 x 24 x 5 x 5)

"""g"""
mb = x.size()[0]
n_channels = x.size()[1]
d = x.size()[2]
# x_flat = (64 x 25 x 24)
x_flat = x.view(mb,n_channels,d*d).permute(0,2,1)

# add coordinates
x_flat = torch.cat([x_flat, self.coord_tensor],2)

# add question everywhere
qst = torch.unsqueeze(qst, 1)
qst = qst.repeat(1,25,1)
qst = torch.unsqueeze(qst, 2)

# cast all pairs against each other
x_i = torch.unsqueeze(x_flat,1) # (64x1x25x26+11)
x_i = x_i.repeat(1,25,1,1) # (64x25x25x26+11)
x_j = torch.unsqueeze(x_flat,2) # (64x25x1x26+11)
x_j = torch.cat([x_j,qst],3)
x_j = x_j.repeat(1,1,25,1) # (64x25x25x26+11)

# concatenate all together
x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+11)

# reshape for passing through network
x_ = x_full.view(mb*d*d*d*d,63)
x_ = self.g_fc1(x_)
Expand All @@ -98,39 +158,34 @@ def forward(self, img, qst):
x_ = F.relu(x_)
x_ = self.g_fc4(x_)
x_ = F.relu(x_)

# reshape again and sum
x_g = x_.view(mb,d*d*d*d,256)
x_g = x_g.sum(1).squeeze()

"""f"""
x_f = self.f_fc1(x_g)
x_f = F.relu(x_f)
x_f = self.f_fc2(x_f)
x_f = F.relu(x_f)
x_f = F.dropout(x_f)
x_f = self.f_fc3(x_f)

return self.fcout(x_f)

return F.log_softmax(x_f)

class CNN_MLP(BasicModel):
def __init__(self, args):
super(CNN_MLP, self).__init__(args, 'CNNMLP')

def train_(self, input_img, input_qst, label):
    """Run one optimiser step on the batch and return its accuracy in percent."""
    self.optimizer.zero_grad()
    scores = self(input_img, input_qst)
    F.nll_loss(scores, label).backward()
    self.optimizer.step()
    # The arg-max over dim 1 is taken as the predicted class per row.
    _, predicted = scores.data.max(1)
    n_hits = predicted.eq(label.data).cpu().sum()
    return n_hits * 100. / len(label)

self.conv = ConvInputModel()
self.fc1 = nn.Linear(5 * 5 * 24, 256)
self.fcout = FCOutputModel()

def test_(self, input_img, input_qst, label):
output = self(input_img, input_qst)
pred = output.data.max(1)[1]
correct = pred.eq(label.data).cpu().sum()
accuracy = correct * 100. / len(label)
return accuracy
self.optimizer = optim.Adam(self.parameters(), lr=args.lr)

def forward(self, img, qst):
    """Conv features -> flatten -> fc1 + ReLU -> shared output head."""
    conv_maps = self.conv(img, qst)  # source comment gives (64 x 24 x 5 x 5)

    # Flatten everything except the leading (batch) dimension for fc1.
    flat = conv_maps.view(conv_maps.size(0), -1)
    hidden = F.relu(self.fc1(flat))

    return self.fcout(hidden)

def save_model(self, epoch):
    """Save this module's ``state_dict`` to ``model/epoch_<epoch>.pth``."""
    checkpoint_path = 'model/epoch_{}.pth'.format(epoch)
    torch.save(self.state_dict(), checkpoint_path)

0 comments on commit c49490b

Please sign in to comment.