In [1]:
%%writefile '../db/models/training_session.py'

from sqlalchemy import Column, Integer, Numeric, String, Text, VARCHAR, DECIMAL, DateTime, Float, Boolean, LargeBinary, Binary, SmallInteger, BigInteger
from sqlalchemy import select, delete, update, insert
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy import Integer, DateTime
from wsgiref.handlers import format_date_time

# from db import db_init
# from base import Base
# from db_base import DBBase
# from utils import Utils

from db import db_init
from base import Base
from db_base import DBBase
from utils import Utils

import numpy as np
from datetime import datetime

class TrainingSession(Base, DBBase):
    """SQLAlchemy-mapped row of the ``train_sessions`` table describing one training run.

    Columns record the model set being trained (``modset_id``), the start and
    end timestamps, the configured epoch count and batch size, the checkpoint
    the run started from, and the average whole seconds spent per epoch.

    Typical lifecycle: construct, call ``start_session()`` (assigns an id and
    start time, then saves), train, then call ``end_session()`` (stamps the end
    time and average epoch time, then saves again).
    """
    __tablename__ = 'train_sessions'
    __table_args__ = {'extend_existing': True}
    # Database handle, created once when the class body executes.
    db = db_init('SQLITE')

    # Ids are not generated by the database (autoincrement=False);
    # generate_sess_id() derives them from the wall clock instead.
    train_sess_id = Column(BigInteger, primary_key=True, autoincrement=False)
    modset_id = Column(VARCHAR(50))
    start_datetime = Column(DateTime)
    end_datetime = Column(DateTime)
    epochs_total = Column(Integer)
    batch_size = Column(Integer)
    avg_time_secs = Column(Integer)
    initial_checkpoint = Column(VARCHAR(50))

    def __init__(self, modsetID, trainSessID=0):
        """Create a session bound to a model set.

        Args:
            modsetID: value stored in ``modset_id``.
            trainSessID: id of an existing session. It is applied only when it
                is > 0; otherwise the id stays unset until ``generate_sess_id()``
                (called by ``start_session()``) assigns one.
        """
        self.reset()
        self.modset_id = modsetID
        if trainSessID > 0:
            self.train_sess_id = trainSessID

        # Hand the model class, its key column and the current key value to the
        # DBBase mixin (its exact use of them is defined outside this file).
        super().setupDBBase(TrainingSession, TrainingSession.train_sess_id, self.train_sess_id)

    def reset(self):
        """Zero the epoch count, batch size and average epoch time."""
        self.epochs_total = 0
        self.batch_size = 0
        # Must target the instance attribute; a bare local assignment would
        # leave the mapped column untouched.
        self.avg_time_secs = 0

    def start_session(self, initialVersion=1, epochsTotal=1, batchSize=32):
        """Record the run configuration, assign a new id and start time, and save.

        Args:
            initialVersion: checkpoint the run starts from (``initial_checkpoint``).
            epochsTotal: number of epochs planned for the run.
            batchSize: training batch size.

        Returns:
            The newly generated session id (an ``int``).
        """
        self.initial_checkpoint = initialVersion
        self.epochs_total = epochsTotal
        self.batch_size = batchSize

        self.generate_sess_id()
        self.set_start_time()

        self.db_save()

        return self.train_sess_id

    def end_session(self):
        """Stamp the end time, compute the average seconds per epoch, and save.

        The average is derived from the stored ``start_datetime`` and
        ``end_datetime``, so it agrees with the persisted timestamps. It is
        truncated to whole seconds. When ``epochs_total`` is 0 or unset, no
        per-epoch average exists and ``avg_time_secs`` is stored as 0 instead
        of raising ``ZeroDivisionError``.
        """
        self.set_end_time()
        elapsed_secs = (self.end_datetime - self.start_datetime).total_seconds()
        if self.epochs_total:
            self.avg_time_secs = int(elapsed_secs / self.epochs_total)
        else:
            self.avg_time_secs = 0

        self.db_save()  # until we fix db_update func
        # self.db_update({'end_datetime': self.end_datetime, 'avg_time_secs': self.avg_time_secs},
        #                {'train_sess_id': self.train_sess_id})

    def generate_sess_id(self):
        """Assign and return a session id built from the current local time.

        The id is the digits of ``YYYYmmddHHMMSS`` converted to ``int``, so it
        matches the ``BigInteger`` column type and can be passed back into
        ``__init__`` (which compares it against 0). Resolution is one second:
        two sessions started within the same second would get the same id.
        """
        sess_id_format = "%Y%m%d%H%M%S"
        self.date_time = datetime.now()

        self.set_session_id(int(self.date_time.strftime(sess_id_format)))
        print("New Session ID: ", self.train_sess_id)
        return self.train_sess_id

    def set_session_id(self, sess_id):
        """Set ``train_sess_id``."""
        self.train_sess_id = sess_id

    def get_session_id(self):
        """Return ``train_sess_id``."""
        return self.train_sess_id

    def set_start_time(self):
        """Set ``start_datetime`` to the current local time."""
        self.start_datetime = datetime.now()

    def set_end_time(self):
        """Set ``end_datetime`` to the current local time."""
        self.end_datetime = datetime.now()

    def get_time_elapsed(self):
        """Return seconds (float) elapsed between ``start_datetime`` and now."""
        time_elapsed = datetime.now() - self.start_datetime
        return time_elapsed.total_seconds()
        

Overwriting ../db/models/training_session.py


In [2]:
%%writefile '../db/models/training_log.py'

from sqlalchemy import Column, Integer, Numeric, String, Text, VARCHAR, DECIMAL, DateTime, Float, Boolean, LargeBinary, Binary, SmallInteger, BigInteger
from sqlalchemy import select, delete, update, insert
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy import Integer, DateTime
from wsgiref.handlers import format_date_time

# from db import db_init
# from base import Base
# from db_base import DBBase
# from utils import Utils

from db import db_init
from base import Base
from db_base import DBBase
from utils import Utils

import numpy as np
from datetime import datetime

class TrainingLog(Base, DBBase):
    """SQLAlchemy-mapped row of the ``train_logs`` table: one logged metric point.

    Each row links to a training session (``train_sess_id``) and model set
    (``modset_id``). It stores the epoch number, the loss, one named metric
    value, and the time the entry was written.
    """
    __tablename__ = 'train_logs'
    __table_args__ = {'extend_existing': True}
    # Database handle, created once when the class body executes.
    db = db_init('SQLITE')

    train_log_id = Column(Integer, primary_key=True, autoincrement=True)
    train_sess_id = Column(BigInteger)
    modset_id = Column(VARCHAR(50))
    # This column attribute is named `datetime`. Inside methods, the bare name
    # `datetime` still resolves to the module-level `datetime` class, because
    # class-body names are not part of method scope.
    datetime = Column(DateTime)
    epoch = Column(Integer)
    loss = Column(DECIMAL(35, 30))
    metric_name = Column(VARCHAR(50))
    metric_value = Column(DECIMAL(35, 30))

    def __init__(self, modsetID, trainSessID=0, trainLogID=0):
        """Create a log entry for a model set and (optionally) a session.

        Args:
            modsetID: value stored in ``modset_id``.
            trainSessID: session this entry belongs to (always stored).
            trainLogID: id of an existing log row. It is applied only when it is
                > 0; otherwise the id stays unset (``None``) so the database can
                autoincrement it.
        """
        self.modset_id = modsetID
        self.train_sess_id = trainSessID
        if trainLogID > 0:
            self.train_log_id = trainLogID

        # Hand the model class, its key column and the current key value to the
        # DBBase mixin (its exact use of them is defined outside this file).
        super().setupDBBase(TrainingLog, TrainingLog.train_log_id, self.train_log_id)

    def reset(self):
        """Clear the metric fields and the row id.

        The id goes back to ``None``, which is the same "not yet assigned" state
        ``__init__`` leaves it in. An explicit 0 would be written as the row's
        primary key instead of letting the autoincrement column pick a fresh one.
        """
        self.train_log_id = None
        self.epoch = 0
        self.loss = 0
        self.metric_name = ''
        self.metric_value = 0

    def set_log(self, Epoch, Loss, metricName, metricValue):
        """Fill in one metric point, timestamp it with local time, and save it.

        Args:
            Epoch: epoch number.
            Loss: loss value for that epoch.
            metricName: name of the tracked metric.
            metricValue: value of that metric.
        """
        self.epoch = Epoch
        self.loss = Loss
        self.metric_name = metricName
        self.metric_value = metricValue

        # Module-level datetime class (see the note on the `datetime` column).
        self.datetime = datetime.now()

        # Persist through the DBBase mixin.
        self.db_save()
        


Overwriting ../db/models/training_log.py
