In [37]:
require 'optim'
require 'nn'


Definition of DANN

In [38]:
if not opt then
   print '==> processing options'
   cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Deep Learning - Telecom tutorial')
   cmd:text()
   cmd:text('Options:')
   cmd:option('-learningRate', 0.1, 'learning rate at t=0')
   cmd:option('-domainLambda', 0.1, 'regularization term for transfer learning')
   cmd:option('-batchSize', 500, 'mini-batch size (1 = pure stochastic)')
   cmd:option('-maxEpoch', 50, 'maximum nb of epoch')
   cmd:option('-seed', 0, 'random seed')
   cmd:option('-saveModel', false, 'flag for saving the model on disk at each epoch, if improvement')
   cmd:option('-save', 'results', 'subdirectory to save/log experiments in')
   cmd:text()
   opt = cmd:parse(arg or {})
end
torch.manualSeed(opt.seed)
trainLogger = optim.Logger(paths.concat("results", opt.save))
trainLogger:setNames({'epoch','trainLoss','trainAccuracy','validLoss','validAccuracy'})


# Load the source data, session 1 person 1

In [39]:
 nTrainS1 = 6000
 nValidS1 = 6000
 EEGDataset1 = torch.load('EEG.torch')
 nInputEEG = EEGDataset1:size(3)

sourceTrainSet = torch.Tensor(nTrainS1, nInputEEG)
sourceTrainSet:copy(EEGDataset1[1]:narrow(1,1,nTrainS1):float())
sourceTrainSetLabel = torch.Tensor(nTrainS1, nInputEEG)
sourceTrainSetLabel:copy(EEGDataset1[1]:narrow(1,1,nTrainS1):float())

sourceValidSet = torch.Tensor(nTrainS1,nInputEEG)
sourceValidSet:copy(EEGDataset1[1]:narrow(1,nTrainS1+1,nValidS1):float())
sourceValidSetLabel = torch.Tensor(nValidS1, nInputEEG)
sourceValidSetLabel:copy(EEGDataset1[1]:narrow(1,nTrainS1+1,nValidS1):float())

sourceInputs = torch.Tensor(opt.batchSize,sourceTrainSet:size(2)) 
sourceLabels = torch.Tensor(opt.batchSize,sourceTrainSet:size(2))                                        


# Load the target data, session 2 person 1

In [40]:
 nTrainS2 = 6000
 nValidS2 = 6000
 EEGDataset2 = torch.load('EEGMat.torch')
 nInputEEG = EEGDataset2:size(3)

targetTrainSet = torch.Tensor(nTrainS2, nInputEEG)
targetTrainSet:copy(EEGDataset2[1]:narrow(1,1,nTrainS2):float())
targetTrainSetLabel = torch.Tensor(nTrainS2, nInputEEG)
targetTrainSetLabel:copy(EEGDataset2[1]:narrow(1,1,nTrainS2):float())

targetValidSet = torch.Tensor(nTrainS2,nInputEEG)
targetValidSet:copy(EEGDataset2[1]:narrow(1,nTrainS2+1,nValidS2):float())
targetValidSetLabel = torch.Tensor(nValidS2, nInputEEG)
targetValidSetLabel:copy(EEGDataset2[1]:narrow(1,nTrainS2+1,nValidS2):float())

targetInputs = torch.Tensor(opt.batchSize,targetTrainSet:size(2)) 
targetLabels = torch.Tensor(opt.batchSize,targetTrainSet:size(2)) 


# Definition of DANN structure

In [41]:
hiddenUnits = 5
-- Definition of the encoder
featExtractor = nn.Sequential()
--featExtractor:add(nn.Reshape(nInputEEG))
featExtractor:add(nn.Linear(nInputEEG,hiddenUnits))
featExtractor:add(nn.Sigmoid())
-- featExtractor:add(nn.Linear(5,5))

-- Definition of the decoder
labelPredictor = nn.Sequential()
labelPredictor:add(nn.Linear(hiddenUnits,nInputEEG))
labelPredictor:add(nn.Sigmoid())

-- Definition of the domain classifier
domainClassifier = nn.Sequential()
domainClassifier:add(nn.GradientReversal([lambda = 5]))
domainClassifier:add(nn.Linear(hiddenUnits,1))
domainClassifier:add(nn.Sigmoid())


[string "hiddenUnits = 5..."]:16: unexpected symbol near '[': 

In [42]:
-- Definition of the criterion
labelPredictorCriterion = nn.MSECriterion()
domainClassifierCriterion = nn.BCECriterion()

In [43]:
-- Retrieve the pointers to the parameters and gradParameters from the model for latter use
featExtractorParams,featExtractorGradParams = featExtractor:getParameters()
labelPredictorParams,labelPredictorGradParams = labelPredictor:getParameters()
domainClassifierParams,domainClassifierGradParams = domainClassifier:getParameters()
params = torch.Tensor(featExtractorParams:size(1)+labelPredictorParams:size(1)+domainClassifierParams:size(1))
params:narrow(1,1,featExtractorParams:size(1)):copy(featExtractorParams)
params:narrow(1,featExtractorParams:size(1),labelPredictorParams:size(1)):copy(labelPredictorParams)
params:narrow(1,featExtractorParams:size(1)+labelPredictorParams:size(1),domainClassifierParams:size(1)):copy(domainClassifierParams)
gradParams = torch.Tensor(featExtractorParams:size(1)+labelPredictorParams:size(1)+domainClassifierParams:size(1))
--

print("feat " .. tostring(featExtractorParams:size()))
print("label " .. tostring(labelPredictorParams:size()))
print("domain " .. tostring(domainClassifierParams:size()))


feat  285
[torch.LongStorage of size 1]
	
label  336
[torch.LongStorage of size 1]
	
domain  6
[torch.LongStorage of size 1]
	


# Redefine the train() function

In [44]:

-- Learning function
function train()

   local tick1 = sys.clock()
   
   -- It may help to shuffle the examples
   shuffle = torch.randperm(sourceTrainSet:size(1))
   
   for t = 1,sourceTrainSet:size(1),opt.batchSize do
	  
	  xlua.progress(t,sourceTrainSet:size(1))
	  
	  -- Define the minibatch
	  for i = 1,opt.batchSize do
		 sourceInputs[i]:copy(sourceTrainSet[shuffle[t+i-1]])
		 sourceLabels[i] = sourceTrainSetLabel[shuffle[t+i-1]]
		 targetInputs[i]:copy(targetTrainSet[shuffle[t+i-1]])
		 targetLabels[i] = targetTrainSetLabel[shuffle[t+i-1]]
	  end

	  -- Definition of the evaluation function (closure)
	  local feval = function(x)
		 
		 --featExtractorParams:copy(x)
		 featExtractorParams:copy(x:narrow(1,1,featExtractorParams:size(1)))
		 labelPredictorParams:copy(x:narrow(1,featExtractorParams:size(1),labelPredictorParams:size(1)))
		 domainClassifierParams:copy(x:narrow(1,featExtractorParams:size(1)+labelPredictorParams:size(1),domainClassifierParams:size(1)))
		 
		 featExtractorGradParams:zero()
		 labelPredictorGradParams:zero()
		 domainClassifierGradParams:zero()

		 local feats = featExtractor:forward(sourceInputs)
		 local preds = labelPredictor:forward(feats)
		 local labelCost = labelPredictorCriterion:forward(preds,sourceLabels)

		 -- print("Label cost ".. tostring(labelCost))

		 local labelDfdo = labelPredictorCriterion:backward(preds, sourceLabels)
		 local gradLabelPredictor = labelPredictor:backward(feats, labelDfdo)
		 featExtractor:backward(sourceInputs, gradLabelPredictor)

		 local domPreds = domainClassifier:forward(feats)
		 local domCost = domainClassifierCriterion:forward(domPreds,torch.Tensor(domPreds:size(1),1):fill(1)) -- TODO: ugly, replace with two unique allocations
		 local domDfdo = domainClassifierCriterion:backward(domPreds,torch.Tensor(domPreds:size(1),1):fill(1))
		 local gradDomainClassifier = domainClassifier:backward(feats,domDfdo,opt.domainLambda) --TODO: verify
		 featExtractor:backward(sourceInputs, gradDomainClassifier,opt.domainLambda)

		 --- Target propagation
		 local targetFeats = featExtractor:forward(targetInputs)
		 local targetDomPreds = domainClassifier:forward(targetFeats)
		 local targetDomCost = domainClassifierCriterion:forward(targetDomPreds,torch.Tensor(targetDomPreds:size(1),1):fill(0)) -- TODO: ugly, replace with two unique allocations
		 local targetDomDfdo = domainClassifierCriterion:backward(targetDomPreds,torch.Tensor(targetDomPreds:size(1),1):fill(0))
		 local targetGradDomainClassifier = domainClassifier:backward(targetFeats,targetDomDfdo,opt.domainLambda) --TODO: verify
		 featExtractor:backward(targetInputs, targetGradDomainClassifier,opt.domainLambda)

		 -- print("Domain cost ".. tostring(domCost+targetDomCost))

		 params:narrow(1,1,featExtractorParams:size(1)):copy(featExtractorParams)
		 params:narrow(1,featExtractorParams:size(1),labelPredictorParams:size(1)):copy(labelPredictorParams)
		 params:narrow(1,featExtractorParams:size(1)+labelPredictorParams:size(1),domainClassifierParams:size(1)):copy(domainClassifierParams)
		 gradParams:narrow(1,1,featExtractorGradParams:size(1)):copy(featExtractorGradParams)
		 gradParams:narrow(1,featExtractorGradParams:size(1),labelPredictorGradParams:size(1)):copy(labelPredictorGradParams)
		 gradParams:narrow(1,featExtractorGradParams:size(1)+labelPredictorParams:size(1),domainClassifierGradParams:size(1)):copy(domainClassifierGradParams)

		 return params,gradParams		 
	  end
	  optim.sgd(feval,params,opt)

   end
   print("tick" .. sys.clock()-tick1)
end

# Main function

In [45]:
prevLoss = 10e12

numberEpoch = torch.zero(opt.learningRate, opt.maxEpoch)
MSEepoch = torch.zero(opt.learningRate, opt.maxEpoch)

number
for 
for i = 1,opt.maxEpoch do
   --print('Degug here !!!!!!!!!!!!')
   featExtractor:evaluate()
   labelPredictor:evaluate()
   domainClassifier:evaluate()

   local sourceFeats = featExtractor:forward(sourceTrainSet)
   local sourceTrainPred = labelPredictor:forward(sourceFeats)
   local sourceTrainLoss = labelPredictorCriterion:forward(sourceTrainPred, sourceTrainSetLabel) 
   local sourceDomPreds = domainClassifier:forward(sourceFeats)
   local sourceDomCost = domainClassifierCriterion:forward(sourceDomPreds,torch.Tensor(sourceDomPreds:size(1),1):fill(1)) -- TODO: ugly, replace with two unique allocations

   local targetFeats = featExtractor:forward(targetTrainSet)
   local targetDomPreds = domainClassifier:forward(targetFeats)
   local targetDomCost = domainClassifierCriterion:forward(targetDomPreds,torch.Tensor(targetDomPreds:size(1),1):fill(1)) -- TODO: ugly, replace with two unique allocations

   -- trainConfusion:batchAdd(sourceTrainPred, sourceTrainSetLabel)
   print("EPOCH: " .. i)
   -- print(trainConfusion)
   print(" + Train loss " .. sourceTrainLoss .. " " .. sourceDomCost+targetDomCost)

   local validPred = labelPredictor:forward(featExtractor:forward(sourceValidSet))
   local validLoss = labelPredictorCriterion:forward(validPred, sourceValidSetLabel) 

   local sourceFeats = featExtractor:forward(sourceValidSet)
   local sourceValidPred = labelPredictor:forward(sourceFeats)
   local sourceValidLoss = labelPredictorCriterion:forward(sourceValidPred, sourceValidSetLabel) 
   local sourceDomPreds = domainClassifier:forward(sourceFeats)
   local sourceDomCostValid = domainClassifierCriterion:forward(sourceDomPreds,torch.Tensor(sourceDomPreds:size(1),1):fill(1)) -- TODO: ugly, replace with two unique allocations

   local targetFeats = featExtractor:forward(targetValidSet)
   local targetDomPreds = domainClassifier:forward(targetFeats)
   local targetDomCostValid = domainClassifierCriterion:forward(targetDomPreds,torch.Tensor(targetDomPreds:size(1),1):fill(1)) -- TODO: ugly, replace with two unique allocations

   -- validConfusion:batchAdd(validPred, sourceValidSetLabel)
   -- print(validConfusion)
   print(" + Valid loss " .. validLoss .. " " .. sourceDomCostValid+targetDomCostValid)

   -- trainLogger:add{i, trainLoss, trainConfusion.totalValid * 100, validLoss, validConfusion.totalValid * 100}
   -- trainConfusion:zero()
   -- validConfusion:zero()

   if opt.saveModel then
	  if trainLoss < prevLoss then
		 prevLoss = trainLoss
		 torch.save("model.bin",model)
	  else
		 model = torch.load("model.bin")
	  end
   end

   featExtractor:training()
   labelPredictor:training()
   domainClassifier:training()
   train()
end


EPOCH: 1	
 + Train loss 0.021105947069843 1.3914257887728	


 + Valid loss 0.019533513803055 1.3896508546989	


tick0.25401210784912	


EPOCH: 2	
 + Train loss 0.020918737900082 1.3912878743633	


 + Valid loss 0.019383144157362 1.3895052199485	


tick0.34698390960693	




EPOCH: 3	
 + Train loss 0.020716805896761 1.3911574803387	


 + Valid loss 0.019221345560611 1.3893659805825	


tick0.37539291381836	


EPOCH: 4	
 + Train loss 0.020516959380778 1.3910427929954	


 + Valid loss 0.019061799662885 1.3892421041506	


tick0.26014995574951	


EPOCH: 5	
 + Train loss 0.020319382362398 1.3909385630102	


 + Valid loss 0.018904478681414 1.3891285470995	


tick0.26050305366516	


EPOCH: 6	
 + Train loss 0.020123823479712 1.3908516933332	


 + Valid loss 0.018749138544213 1.3890316831236	


tick0.25833582878113	


EPOCH: 7	
 + Train loss 0.019930238687972 1.390770924091	


 + Valid loss 0.018595837187746 1.3889407511028	


tick0.38537406921387	


EPOCH: 8	
 + Train loss 0.019738481823579 1.3907006589593	


 + Valid loss 0.018444641429018 1.3888600141691	


tick0.25638604164124	


EPOCH: 9	
 + Train loss 0.019548774366649 1.3906396323455	


 + Valid loss 0.018295274718907 1.3887882482663	


tick0.25610208511353	


EPOCH: 10	
 + Train loss 0.019360719568713 1.3905840429943	


 + Valid loss 0.018147809549562 1.3887215126041	


tick0.38718891143799	


EPOCH: 11	
 + Train loss 0.019174663612223 1.3905359283852	


 + Valid loss 0.01800225711902 1.3886621997641	


tick0.24027299880981	


EPOCH: 12	
 + Train loss 0.018990213782217 1.3904902563655	


 + Valid loss 0.017858472326305 1.3886050685754	


tick0.25455498695374	


EPOCH: 13	
 + Train loss 0.018807612826495 1.3904505524786	


 + Valid loss 0.017716341736084 1.3885536407284	


tick0.25382709503174	


EPOCH: 14	
 + Train loss 0.018626358313515 1.3904112623019	


 + Valid loss 0.017575924549539 1.3885024318334	


tick0.31298303604126	


EPOCH: 15	
 + Train loss 0.01844710438545 1.3903789300205	


 + Valid loss 0.017437298702291 1.3884581538113	


tick0.18889808654785	


EPOCH: 16	
 + Train loss 0.018269084339236 1.3903464671296	


 + Valid loss 0.017300185330157 1.388413304281	


tick0.17957401275635	


EPOCH: 17	
 + Train loss 0.018092733602008 1.3903163557784	


 + Valid loss 0.017164687221165 1.3883708213307	


tick0.24998998641968	


EPOCH: 18	
 + Train loss 0.017917982328338 1.3902867414066	


 + Valid loss 0.017030752455367 1.3883286898607	


tick0.38257193565369	


EPOCH: 19	
 + Train loss 0.017744610077479 1.3902581709678	


 + Valid loss 0.016898440867408 1.3882875583692	


tick0.24127197265625	


EPOCH: 20	
 + Train loss 0.017572655852215 1.3902332721167	


 + Valid loss 0.016767519497299 1.3882497396638	


tick0.26531410217285	


EPOCH: 21	
 + Train loss 0.017402317202052 1.390205086698	


 + Valid loss 0.016638166755949 1.388208827041	


tick0.25915098190308	


EPOCH: 22	
 + Train loss 0.017233178699109 1.3901799464187	


 + Valid loss 0.01651022641821 1.3881706389842	


tick0.38977193832397	


EPOCH: 23	
 + Train loss 0.017065578108851 1.3901537004819	


 + Valid loss 0.016383765076647 1.3881314358838	


tick0.19489693641663	


EPOCH: 24	
 + Train loss 0.016899293396896 1.3901269399819	


 + Valid loss 0.01625880939548 1.3880917256624	


tick0.25109314918518	


EPOCH: 25	
 + Train loss 0.016734500936676 1.3901010607817	


 + Valid loss 0.016135175174972 1.3880528127734	


tick0.24711084365845	


EPOCH: 26	
 + Train loss 0.016570876439993 1.3900745563197	


 + Valid loss 0.016012877555998 1.3880131473459	


tick0.25326800346375	


EPOCH: 27	
 + Train loss 0.016408710684144 1.3900475835095	


 + Valid loss 0.015892090962996 1.3879731689146	


tick0.29470181465149	


EPOCH: 28	
 + Train loss 0.016247926647655 1.3900214295685	


 + Valid loss 0.015772504687798 1.3879338681082	


tick0.27108812332153	


EPOCH: 29	
 + Train loss 0.016088021818455 1.3899938978469	


 + Valid loss 0.015654202154886 1.3878929275417	


tick0.25234198570251	


EPOCH: 30	
 + Train loss 0.015929993009504 1.3899647147461	


 + Valid loss 0.015537456993349 1.3878509052984	


tick0.39005517959595	


EPOCH: 31	
 + Train loss 0.015773211465656 1.3899380859646	


 + Valid loss 0.015421948216688 1.3878112409453	


tick0.25341796875	


EPOCH: 32	


 + Train loss 0.015617376926831 1.3899074071559	


 + Valid loss 0.015307771893876 1.3877676033007	


tick0.25076103210449	


EPOCH: 33	
 + Train loss 0.015462873539441 1.3898775630085	


 + Valid loss 0.015194689754167 1.3877246556292	


tick0.28599095344543	


EPOCH: 34	
 + Train loss 0.015309682226885 1.3898471466779	


 + Valid loss 0.015083062150059 1.3876813024666	


tick0.33353710174561	


EPOCH: 35	
 + Train loss 0.015157850297615 1.3898180813295	


 + Valid loss 0.014972617214786 1.3876392981033	


tick0.25161790847778	


EPOCH: 36	
 + Train loss 0.015007266762905 1.3897846875315	


 + Valid loss 0.014863456421493 1.3875931975207	


tick0.24961018562317	


EPOCH: 37	
 + Train loss 0.014857659580963 1.3897513954456	


 + Valid loss 0.014755484233674 1.3875470396904	


tick0.24684119224548	


EPOCH: 38	
 + Train loss 0.014709488774768 1.3897185908706	


 + Valid loss 0.014648819546654 1.3875015890683	


tick0.33693599700928	


EPOCH: 39	
 + Train loss 0.014562488766253 1.3896862365536	


 + Valid loss 0.014543262719933 1.3874564772845	


tick0.23816800117493	


EPOCH: 40	
 + Train loss 0.014416831064868 1.389653342626	


 + Valid loss 0.014438975638274 1.3874110325723	


tick0.26891994476318	


EPOCH: 41	
 + Train loss 0.014272338056876 1.3896181759374	


 + Valid loss 0.014335934316432 1.3873635243391	


tick0.26651692390442	


EPOCH: 42	
 + Train loss 0.014128767099694 1.3895818903925	


 + Valid loss 0.014233907501831 1.3873146486268	


tick0.19757390022278	


EPOCH: 43	
 + Train loss 0.01398696350382 1.389547585498	


 + Valid loss 0.014133398242931 1.3872683140609	


tick0.19958996772766	


EPOCH: 44	
 + Train loss 0.013846030339996 1.3895129487277	


 + Valid loss 0.014033895999839 1.3872215296469	


tick0.19093108177185	


EPOCH: 45	
 + Train loss 0.013706505239908 1.3894775636395	


 + Valid loss 0.013935620116656 1.3871741486534	


tick0.38619112968445	


EPOCH: 46	
 + Train loss 0.013567835760991 1.3894401132777	


 + Valid loss 0.013838379793536 1.3871246122295	


tick0.16002798080444	


EPOCH: 47	
 + Train loss 0.013430739097957 1.3894019379394	


 + Valid loss 0.013742463583416 1.3870748252754	


tick0.16458606719971	


EPOCH: 48	
 + Train loss 0.013294663303753 1.3893684762687	


 + Valid loss 0.013647555605354 1.3870294841487	


tick0.16563510894775	


EPOCH: 49	
 + Train loss 0.013159919561743 1.3893299619971	


 + Valid loss 0.013553884597145 1.3869794791011	


tick0.24798393249512	


EPOCH: 50	
 + Train loss 0.013026273773509 1.3892921811724	


 + Valid loss 0.013461325413625 1.3869302191037	


tick0.16096186637878	
