
upd: DAN code
easezyc committed Oct 22, 2019
1 parent 90fbeed commit fbec92e
Showing 5 changed files with 121 additions and 114 deletions.
134 changes: 65 additions & 69 deletions code/deep/DAN/DAN.py
@@ -9,20 +9,20 @@
import data_loader
import ResNet as models
from torch.utils import model_zoo
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Training settings
batch_size = 32
epochs = 200
iteration=10000
lr = 0.01
momentum = 0.9
no_cuda =False
seed = 8
log_interval = 10
l2_decay = 5e-4
root_path = "./dataset/"
source_name = "amazon"
target_name = "webcam"
root_path = "/data/zhuyc/OFFICE31/"
src_name = "amazon"
tgt_name = "dslr"

cuda = not no_cuda and torch.cuda.is_available()

@@ -32,93 +32,89 @@

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

source_loader = data_loader.load_training(root_path, source_name, batch_size, kwargs)
target_train_loader = data_loader.load_training(root_path, target_name, batch_size, kwargs)
target_test_loader = data_loader.load_testing(root_path, target_name, batch_size, kwargs)
src_loader = data_loader.load_training(root_path, src_name, batch_size, kwargs)
tgt_train_loader = data_loader.load_training(root_path, tgt_name, batch_size, kwargs)
tgt_test_loader = data_loader.load_testing(root_path, tgt_name, batch_size, kwargs)

len_source_dataset = len(source_loader.dataset)
len_target_dataset = len(target_test_loader.dataset)
len_source_loader = len(source_loader)
len_target_loader = len(target_train_loader)
src_dataset_len = len(src_loader.dataset)
tgt_dataset_len = len(tgt_test_loader.dataset)
src_loader_len = len(src_loader)
tgt_loader_len = len(tgt_train_loader)

def load_pretrain(model):
url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
pretrained_dict = model_zoo.load_url(url)
model_dict = model.state_dict()
for k, v in model_dict.items():
if not "cls_fc" in k:
model_dict[k] = pretrained_dict[k[k.find(".") + 1:]]
model.load_state_dict(model_dict)
return model

def train(epoch, model):
LEARNING_RATE = lr / math.pow((1 + 10 * (epoch - 1) / epochs), 0.75)
print('learning rate{: .4f}'.format(LEARNING_RATE) )
optimizer = torch.optim.SGD([
def train(model):
src_iter = iter(src_loader)
tgt_iter = iter(tgt_train_loader)
correct = 0
for i in range(1, iteration+1):
model.train()
LEARNING_RATE = lr / math.pow((1 + 10 * (i - 1) / (iteration)), 0.75)
if (i-1)%100==0:
print('learning rate{: .4f}'.format(LEARNING_RATE) )
optimizer = torch.optim.SGD([
{'params': model.sharedNet.parameters()},
{'params': model.cls_fc.parameters(), 'lr': LEARNING_RATE},
], lr=LEARNING_RATE / 10, momentum=momentum, weight_decay=l2_decay)

model.train()

iter_source = iter(source_loader)
iter_target = iter(target_train_loader)
num_iter = len_source_loader
for i in range(1, num_iter):
data_source, label_source = iter_source.next()
data_target, _ = iter_target.next()
if i % len_target_loader == 0:
iter_target = iter(target_train_loader)
try:
src_data, src_label = src_iter.next()
except Exception as err:
src_iter=iter(src_loader)
src_data, src_label = src_iter.next()

try:
tgt_data, _ = tgt_iter.next()
except Exception as err:
tgt_iter=iter(tgt_train_loader)
tgt_data, _ = tgt_iter.next()

if cuda:
data_source, label_source = data_source.cuda(), label_source.cuda()
data_target = data_target.cuda()
data_source, label_source = Variable(data_source), Variable(label_source)
data_target = Variable(data_target)
src_data, src_label = src_data.cuda(), src_label.cuda()
tgt_data = tgt_data.cuda()

optimizer.zero_grad()
label_source_pred, loss_mmd = model(data_source, data_target)
loss_cls = F.nll_loss(F.log_softmax(label_source_pred, dim=1), label_source)
gamma = 2 / (1 + math.exp(-10 * (epoch) / epochs)) - 1
loss = loss_cls + gamma * loss_mmd
src_pred, mmd_loss = model(src_data, tgt_data)
cls_loss = F.nll_loss(F.log_softmax(src_pred, dim=1), src_label)
lambd = 2 / (1 + math.exp(-10 * (i) / iteration)) - 1
loss = cls_loss + lambd * mmd_loss
loss.backward()
optimizer.step()
if i % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tsoft_Loss: {:.6f}\tmmd_Loss: {:.6f}'.format(
epoch, i * len(data_source), len_source_dataset,
100. * i / len_source_loader, loss.data[0], loss_cls.data[0], loss_mmd.data[0]))

print('Train iter: {} [({:.0f}%)]\tLoss: {:.6f}\tsoft_Loss: {:.6f}\tmmd_Loss: {:.6f}'.format(
i, 100. * i / iteration, loss.item(), cls_loss.item(), mmd_loss.item()))

if i%(log_interval*20)==0:
t_correct = test(model)
if t_correct > correct:
correct = t_correct
print('src: {} to tgt: {} max correct: {} max accuracy{: .2f}%\n'.format(
src_name, tgt_name, correct, 100. * correct / tgt_dataset_len ))

def test(model):
model.eval()
test_loss = 0
correct = 0

for data, target in target_test_loader:
if cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
s_output, t_output = model(data, data)
test_loss += F.nll_loss(F.log_softmax(s_output, dim = 1), target, size_average=False).data[0] # sum up batch loss
pred = s_output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len_target_dataset
with torch.no_grad():
for tgt_test_data, tgt_test_label in tgt_test_loader:
if cuda:
tgt_test_data, tgt_test_label = tgt_test_data.cuda(), tgt_test_label.cuda()
tgt_test_data, tgt_test_label = Variable(tgt_test_data), Variable(tgt_test_label)
tgt_pred, mmd_loss = model(tgt_test_data, tgt_test_data)
test_loss += F.nll_loss(F.log_softmax(tgt_pred, dim = 1), tgt_test_label, reduction='sum').item() # sum up batch loss
pred = tgt_pred.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(tgt_test_label.data.view_as(pred)).cpu().sum()

test_loss /= tgt_dataset_len
print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
target_name, test_loss, correct, len_target_dataset,
100. * correct / len_target_dataset))
tgt_name, test_loss, correct, tgt_dataset_len,
100. * correct / tgt_dataset_len))
return correct


if __name__ == '__main__':
model = models.DANNet(num_classes=31)
correct = 0
print(model)
if cuda:
model.cuda()
model = load_pretrain(model)
for epoch in range(1, epochs + 1):
train(epoch, model)
t_correct = test(model)
if t_correct > correct:
correct = t_correct
print('source: {} to target: {} max correct: {} max accuracy{: .2f}%\n'.format(
source_name, target_name, correct, 100. * correct / len_target_dataset ))
train(model)
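
The schedules used inside `train()` above follow the common progress-based recipe: the SGD learning rate decays as lr / (1 + 10·p)^0.75 and the weight on the MMD term ramps up as 2 / (1 + e^(−10·p)) − 1. Below is a small standalone sketch of how the two quantities evolve over training (the sampled iteration numbers are illustrative only):

```python
import math

lr, iteration = 0.01, 10000  # same settings as in DAN.py above

for i in (1, 1000, 5000, 10000):
    p = (i - 1) / iteration                                # training progress in [0, 1)
    eta = lr / math.pow(1 + 10 * p, 0.75)                  # annealed SGD learning rate
    lambd = 2 / (1 + math.exp(-10 * i / iteration)) - 1    # MMD weight ramps from ~0 to ~1
    print('iter {:5d}: lr={:.5f}, lambda={:.3f}'.format(i, eta, lambd))
```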

12 changes: 7 additions & 5 deletions code/deep/DAN/README.md
@@ -1,14 +1,12 @@
# DAN
A PyTorch implementation of '[Learning Transferable Features with Deep Adaptation Networks](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-adaptation-networks-icml15.pdf)'.
The contributions of this paper are summarized as fol-
lows.
The contributions of this paper are summarized as follows.
* They propose a novel deep neural network architecture for domain adaptation, in which all the layers corresponding to task-specific features are adapted in a layerwise manner, hence benefiting from “deep adaptation.”
* They explore multiple kernels for adapting deep representations, which substantially enhances adaptation effectiveness compared to single-kernel methods; the model can yield unbiased deep features with statistical guarantees (a minimal sketch of this multi-kernel MMD follows the results table below).

## Requirement
* python 3
* pytorch 0.3.1
* torchvision 0.2.0
* pytorch 1.0

## Usage
1. You can download the Office31 dataset [here](https://pan.baidu.com/s/1o8igXT4#list/path=%2F), and then unrar it into ./dataset/.
@@ -18,4 +16,8 @@
## Results on Office31
| Method | A - W | D - W | W - D | A - D | D - A | W - A | Average |
|:--------------:|:-----:|:-----:|:-----:|:-----:|:----:|:----:|:-------:|
| DAN | 83.8±0.4 | 96.8±0.2 | 99.5±0.1 | 78.4±0.2 | 66.7±0.3 | 62.7±0.2 | 81.3 |
| DANori | 83.8±0.4 | 96.8±0.2 | 99.5±0.1 | 78.4±0.2 | 66.7±0.3 | 62.7±0.2 | 81.3 |
| DANlast | 81.6±0.7 | 97.2±0.1 | 99.5±0.1 | 80.0±0.7 | 66.2±0.6 | 65.6±0.4 | 81.7 |
| DANmax | 82.6±0.7 | 97.7±0.1 | 100.0±0.0 | 83.1±0.9 | 66.8±0.3 | 66.6±0.4 | 82.8 |

> Note that the **DANori** results come from this [paper](http://ise.thss.tsinghua.edu.cn/~mlong/doc/multi-adversarial-domain-adaptation-aaai18.pdf), which shares authors with DAN. **DANlast** is the result of the last epoch, and **DANmax** is the best result across all epochs; I obtained both **DANlast** and **DANmax** by running this code.
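
To illustrate the "multiple kernels" point from the contributions above: the adaptation loss is a multi-kernel maximum mean discrepancy between source and target features. Below is a minimal sketch of such an estimate, assuming (batch, dim) feature tensors; the function name `mk_mmd` and the exact bandwidth heuristic are illustrative choices, not necessarily identical to this repository's mmd.py (shown further down):

```python
import torch

def mk_mmd(source, target, kernel_mul=2.0, kernel_num=5):
    """Biased multi-kernel MMD^2 between two (batch, dim) feature batches."""
    n = source.size(0)
    total = torch.cat([source, target], dim=0)
    # pairwise squared Euclidean distances between all 2n samples
    dist = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
    # a geometric family of Gaussian bandwidths centred on the mean pairwise distance
    bandwidth = dist.detach().mean() / (kernel_mul ** (kernel_num // 2))
    kernels = sum(torch.exp(-dist / (bandwidth * kernel_mul ** k))
                  for k in range(kernel_num))
    xx, yy, xy = kernels[:n, :n], kernels[n:, n:], kernels[:n, n:]
    return xx.mean() + yy.mean() - 2 * xy.mean()

# identical batches give (near) zero, shifted batches give a larger value
feats = torch.randn(32, 2048)
print(mk_mmd(feats, feats).item(), mk_mmd(feats, feats + 1.0).item())
```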
76 changes: 49 additions & 27 deletions code/deep/DAN/ResNet.py
@@ -13,11 +13,16 @@
}


def conv3x3(in_planes, out_planes, stride=1):
def conv3x3(in_planes, out_planes, stride=1,groups=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)



class BasicBlock(nn.Module):
expansion = 1
@@ -54,7 +59,8 @@ def forward(self, x):
class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, norm_layer=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
@@ -92,43 +98,58 @@ def forward(self, x):

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
def __init__(self, block, layers, num_classes=1000,zero_init_residual=False,
groups=1, width_per_group=64, norm_layer=None):
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
if norm_layer is None:
norm_layer = nn.BatchNorm2d

self.inplanes = 64
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.baselayer = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def _make_layer(self, block, planes, blocks, stride=1):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)

def _make_layer(self, block, planes, blocks, stride=1,norm_layer=None):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, norm_layer))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, norm_layer=norm_layer))

return nn.Sequential(*layers)

@@ -144,14 +165,15 @@ def forward(self, x):
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
#x=self.fc(x)

return x

class DANNet(nn.Module):

def __init__(self, num_classes=31):
super(DANNet, self).__init__()
self.sharedNet = resnet50(False)
self.sharedNet = resnet50(True)
self.cls_fc = nn.Linear(2048, num_classes)

def forward(self, source, target):
@@ -176,4 +198,4 @@ def resnet50(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
return model
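
The hunk above cuts off before the body of `DANNet.forward`. In DAN-style models it typically pushes both batches through the shared backbone, computes the MMD between the pooled features during training only, and classifies the source batch; the class below is a hypothetical sketch along those lines (the `backbone` and `mmd_loss` arguments are placeholders), not necessarily the exact code in this commit:

```python
import torch
import torch.nn as nn

class DANNetSketch(nn.Module):
    """Hypothetical reconstruction of DANNet's forward pass, for illustration only."""
    def __init__(self, backbone, mmd_loss, num_classes=31):
        super(DANNetSketch, self).__init__()
        self.sharedNet = backbone        # feature extractor returning (batch, 2048)
        self.cls_fc = nn.Linear(2048, num_classes)
        self.mmd_loss = mmd_loss         # e.g. a multi-kernel MMD such as mk_mmd above

    def forward(self, source, target):
        loss = torch.tensor(0.)
        source = self.sharedNet(source)
        if self.training:
            target = self.sharedNet(target)
            loss = self.mmd_loss(source, target)   # align source/target feature distributions
        pred = self.cls_fc(source)                 # classify the source (or test) batch only
        return pred, loss
```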
1 change: 0 additions & 1 deletion code/deep/DAN/dataset/README.md

This file was deleted.

12 changes: 0 additions & 12 deletions code/deep/DAN/mmd.py
@@ -3,18 +3,6 @@

import torch

# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):
delta = f_of_X - f_of_Y
loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
return loss

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.size()[0])+int(target.size()[0])
total = torch.cat([source, target], dim=0)
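
The comment block deleted from mmd.py above describes the linear-kernel estimator that `mmd_linear` implemented: h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) − k(x_i, y_j) − k(x_j, y_i) = [f(x_i) − f(y_i)]^T [f(x_j) − f(y_j)]. As a quick sanity check of that identity, averaging h over all index pairs reproduces `mmd_linear` exactly (equal-sized, index-paired batches are assumed):

```python
import torch

def mmd_linear(f_of_X, f_of_Y):
    # the estimator deleted in this commit: mean of delta_i^T delta_j over all pairs (i, j)
    delta = f_of_X - f_of_Y
    return torch.mean(torch.mm(delta, delta.t()))

x, y = torch.randn(8, 16), torch.randn(8, 16)
h = torch.einsum('ik,jk->ij', x - y, x - y)   # h[i, j] = [f(x_i)-f(y_i)]^T [f(x_j)-f(y_j)]
assert torch.allclose(mmd_linear(x, y), h.mean())
```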
