corenel/pytorch-adda

PyTorch-ADDA

A PyTorch implementation of Adversarial Discriminative Domain Adaptation (ADDA).

Environment

  • Python 3.6
  • PyTorch 0.2.0

Usage

I have only tested MNIST -> USPS. To reproduce that experiment, run:

python3 main.py

Network

This experiment uses three simple networks:

  • LeNet encoder

    LeNetEncoder (
      (encoder): Sequential (
        (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
        (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (2): ReLU ()
        (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
        (4): Dropout2d (p=0.5)
        (5): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (6): ReLU ()
      )
      (fc1): Linear (800 -> 500)
    )
    
  • LeNet classifier

    LeNetClassifier (
      (fc2): Linear (500 -> 10)
    )
    
  • Discriminator

    Discriminator (
      (layer): Sequential (
        (0): Linear (500 -> 500)
        (1): ReLU ()
        (2): Linear (500 -> 500)
        (3): ReLU ()
        (4): Linear (500 -> 2)
        (5): LogSoftmax ()
      )
    )
    
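The 800 input features of fc1 follow from the encoder layers printed above. Here is a minimal sketch of that arithmetic in plain Python, assuming 28x28 single-channel inputs (the MNIST size, with USPS presumably resized to match) and no padding, since the printed Conv2d layers show none. The helper names `conv_out` and `lenet_encoder_features` are hypothetical and not part of this repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_encoder_features(input_size=28):
    """Trace one spatial dimension through the encoder printed above.

    Assumes a square input of input_size x input_size pixels (28 is an
    assumption, not stated in the README). Dropout2d and ReLU do not
    change the shape, so they are skipped.
    """
    size = conv_out(input_size, kernel=5)      # (0) Conv2d(1, 20, 5x5)
    size = conv_out(size, kernel=2, stride=2)  # (1) MaxPool2d 2x2, stride 2
    size = conv_out(size, kernel=5)            # (3) Conv2d(20, 50, 5x5)
    size = conv_out(size, kernel=2, stride=2)  # (5) MaxPool2d 2x2, stride 2
    channels = 50                              # output channels of the second conv
    return channels * size * size


# With a 28x28 input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4,
# so the flattened feature vector has 50 * 4 * 4 entries.
print(lenet_encoder_features())  # 800
```

Under these assumptions the result matches the `Linear (800 -> 500)` layer; a different input size would need a different fc1 width.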

Result

                                     MNIST (Source)   USPS (Target)
Source Encoder + Source Classifier   99.140000%       83.978495%
Target Encoder + Source Classifier   -                97.634409%

Domain adaptation works: on USPS, accuracy rises from 83.98% with the source encoder to 97.63% with the adapted target encoder.
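To put the USPS numbers in perspective, the gain can be read either as accuracy points or as the share of errors removed. The snippet below is just arithmetic on the two USPS figures reported above; the function names are hypothetical. It works out to roughly 13.7 accuracy points, or about 85% of the source-only errors eliminated.

```python
def absolute_gain(before_acc, after_acc):
    """Accuracy improvement in percentage points."""
    return after_acc - before_acc


def relative_error_reduction(before_acc, after_acc):
    """Fraction of the original errors that the adapted model removes."""
    before_err = 100.0 - before_acc
    after_err = 100.0 - after_acc
    return (before_err - after_err) / before_err


# USPS accuracies taken from the results above.
source_only = 83.978495  # Source Encoder + Source Classifier
adapted = 97.634409      # Target Encoder + Source Classifier

print(round(absolute_gain(source_only, adapted), 2))
print(round(relative_error_reduction(source_only, adapted), 3))
```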
