In [None]:
import torch as T
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import sklearn
import sklearn.pipeline
import sklearn.preprocessing
import numpy as np

In [None]:
class Environment(object):
    """Simulated polishing process used as the reinforcement-learning environment.

    The observation is a single float, ``self.state``. ``reset`` sets it to
    8.62, and every non-terminal ``step`` replaces it with the Sa roughness
    returned by ``model_polish``.

    The hidden simulation state lives on the instance and is rebuilt by
    ``reset``:
    - per-asperity heights and radii;
    - the coefficients ``alpha``, ``beta`` and ``sigma``;
    - the stage number and the elapsed time.

    These module-level names are read at call time and must be defined
    elsewhere in the notebook: ``Force_list``, ``speed_list``, ``power``,
    ``alpha_list``, ``beta_list``, ``sigma_list``, ``segment_time``,
    ``stage_nb``, ``E`` and ``dd``.
    """
    def __init__(self,threshold):
        # An episode ends once the observed roughness is <= threshold.
        self.threshold = threshold

    def sample_state_space(self):
        """Draw a random observation uniformly from [threshold, 8.62) as a float32 array of shape (1,)."""
        state_space = np.random.uniform(self.threshold,8.62)
        return np.array([state_space],dtype=np.float32)

    def reset(self):
        """Restore the initial surface and process parameters and return the initial state.

        The list literals are rebuilt on every call, so values assigned by
        ``step`` in a previous episode never carry over into the next one.
        """
        # Initial height of every asperity (one entry per asperity).
        self.heights = [32.1391165339522, 39.7411475115927, 48.8789681173444, 36.510653648217, 30.436839051169, 13.2002307908293, 36.2627740696941, 51.8926983676226, 44.5610562717815, 27.836062539881, 48.5091630138068, 48.5259168458412, 45.7316877077849, 42.8937370822758, 25.0964922186579, 39.643438883657, 38.2251063770881, 36.5818322363791, 34.0845002832226, 46.4385051254061, 19.6387697004397, 36.3843230813292, 40.3204128957389, 47.0620817994192, 13.3158316828047, 30.6433648244663, 39.4846385831796, 22.2326687254873, 50.2089880964485, 36.2130084098589, 46.6545549035838, 48.0218229304846, 46.822787268674, 7.63463593020801, 31.9720644202413, 34.19672954798, 26.1123264014439, 39.8783146693426, 39.7760536734762, 25.8272041169222, 32.6117431651581, 41.966477748219, 48.0573119652393, 48.9326650228638, 28.9171081131264, 47.3216967502979, 44.8332580434864, 40.2195397584721, 39.8147150889448, 37.0074219641041, 45.4609607749412, 36.0226553591759, 41.1568162125188, 43.4085846263739, 36.0265717318162, 40.6249681084501, 41.910549163324, 43.2343191689344, 35.0597589322889, 38.9922842796777, 38.240305178274, 40.0499490794622, 42.9428180997643, 36.1618772630845, 28.1492624218548, 34.3992955433712, 40.6883933092323, 26.9459650677862, 43.9331273265289, 45.2500821629896, 47.4009683166691, 26.6548564334386, 32.9805547485197, 41.6740943650716, 47.6028910901876, 30.0374865417283, 32.9795971473545, 30.3890390641409, 45.6818161934343, 36.2751544937301, 26.1756107216378, 49.8963862778469, 43.228559972641, 22.2547071512728, 40.2417390124435, 45.0700330664824, 19.9650328213718, 44.7140656179915, 40.5687106497537, 41.1774473343145, 54.5645652801029, 35.7633527501085, 35.1632160790364, 42.2085635582537, 42.6208883007481, 11.0021967872984, 31.8206535743656, 25.3968785203669, 35.911527586781, 26.2856515140912, 41.2365551334996, 46.2547562959813, 20.1656368697262, 30.3519186343145, 15.6006388269469, 33.0288416579266, 38.6018616670649, 41.3894084108647, 37.8221432893306, 34.1439916170005, 
23.292082456716, 39.9801512599729, 34.9740052142451, 39.9512187577806, 21.0548633798168, 30.3592489529254, 25.914413066131, 29.4422709350614, 38.2990931720403, 30.6334620228661, 27.6413891453094, 23.3486997430876, 26.0945014698154, 45.4564268302348, 17.5619319855943, 47.3229701962721, 42.5378883322536, 23.174357767596, 32.616384395068, 37.3643427808566, 37.6952669916477, 37.067669094864, 47.7259585190478, 37.7590787458332, 31.5518881255921, 27.3653745987654, 41.6132320077718, 52.6455870005376, 20.6471563832088, 33.7321493076797, 46.599237414263, 20.6982713790297, 45.4676859744866, 38.6526014382027, 14.574105381066, 34.5376406222927, 31.7387299876553, 36.9272433542646, 55.9311040863286, 36.3100607048287, 28.9657228931701, 34.5173621938915, 37.3268617807314, 29.2962242994313, 37.3281293076907, 28.2641156012371, 32.9139158088578, 41.7175802375092, 20.7646638060207, 21.512975516172, 35.3789202094236, 38.3250288364242, 40.986937247471, 31.9878115348134, 53.4762701447589, 29.6855036419329, 40.3029150132003, 46.2175644016149, 41.8364457284671, 44.9710861142246, 27.9008533126531, 43.2384921309416, 36.033205677353, 43.748174483715, 30.6103539358361, 32.5021599348731, 43.4990114075966, 24.6315374788068, 22.6408444659924, 29.1433437858504, 44.5197116633473, 30.8214143096709, 17.5753936823897, 20.8013660089148, 33.2626726460821, 30.1993249231959, 22.7440954039959, 45.4999679552554, 21.2472068901813, 33.9931872712956, 25.3660872210954, 28.3015993597981, 36.1145847362743, 37.3975375262677, 55.5807558414351, 46.8542587687664, 36.2777044131077, 42.6444607003486, 52.5986859424595, 28.2511355537348, 28.9812575464968, 26.6816941309596, 51.5497261997561, 36.243732128047, 36.5136389839681, 27.9405148587164, 44.0931132886571, 39.0023676625522, 30.8012412921867, 50.6767053959512, 23.8478188532297, 24.4578416650816, 41.6367329546282, 54.3642381618693, 45.3844757553056, 37.5466074559827, 41.049377953872, 41.9350591613031, 34.630458486166, 41.0243359105546, 30.6242656889671, 
43.6823152705401, 18.7752817150921, 31.0081550917716, 36.5329030131772, 35.0797525766792, 31.4499298064163, 42.6034302617637, 23.4757731708747, 37.433431458086, 45.3768484612159, 48.1447121489374, 36.3238350414623, 22.9841640245428, 37.6792328710191, 35.8600186030738, 47.1715168934356, 38.5184845785928, 16.3919426916972, 39.7550904667644, 44.2028849471299, 35.2429360387164, 37.5846442651797, 39.1764604315344, 40.2529727393167, 29.2817717269334, 33.5379655760722, 38.1782910971407, 28.7708065982077, 20.574963134752, 38.0968989073047, 34.071785536993, 28.3668624661823, 34.4630965397371, 53.7624654259135, 19.4615307960977, 33.3759394905262, 35.7019732938884, 32.4982525240062, 30.069216230496, 41.6950639703987, 47.4479224112131, 34.3778958590011, 47.5910449149626, 31.4011035873861, 31.5875198832872, 38.8558006036991, 16.7011699312904, 41.519676156467, 30.7276546953641, 30.3931736780406, 28.4548559120399, 30.6250055409953, 22.4041248809058, 33.7570880503435, 39.3019447190877, 52.8239652074375, 35.1437023584409, 55.5233631820805, 41.9557711122658, 46.7049458714273, 32.1165986065506, 12.8495009610989, 28.1317703247554, 27.4154551291574, 43.9194402283895, 31.8991324283584, 12.3541861585855, 40.2062434386218, 38.7285715856269, 38.1467853023835, 25.9203924540944, 52.4761863816713, 19.5944621065576, 33.8008466957476, 44.5738615607229, 10.9187325169424, 37.7560194521201, 46.8228464973401, 27.774158184623, 41.1820549824953, 39.194423903527, 29.4974281688342, 34.794533758244, 42.647174498672, 38.8908043933152, 49.0261991298253, 45.0224437700402, 34.8224162565277, 40.0496050536553, 39.5218342142435, 40.3866493872825, 27.6220993292993, 40.9400554477003, 30.7834811580336, 39.6232417964801, 41.6503184516988, 25.519111899288, 41.2807467068902, 36.948716391894, 37.7280829946377, 36.8274524347244, 26.6515038599076, 28.938120080123, 47.4889248472934, 21.8805619773665, 40.2448192636918, 33.3077634303685, 50.3478268696258, 40.1523561232484, 47.6578626160506, 34.9153051508087, 
33.192812108129, 41.9035115188727, 39.1163287430565, 42.6014023253833, 26.8979853611176, 42.7354567967541, 35.6624026305555, 30.2131004270759, 48.3066606709751, 29.3300677857071, 30.738460341987, 37.4913980863107, 40.3050202576547, 50.4119187879535, 23.5170227211074, 35.1773401906397, 31.103587993847, 33.4533500810804, 41.7027665274737, 31.4563899027081, 37.4903520434992, 45.2592100298983, 31.7950558338685, 37.8952640885234, 29.3920887378514, 27.5508910204881, 37.4391545728686, 42.9429298379461, 42.7852025235093, 16.4656988047589, 40.6223668542976, 26.1371004492868, 43.3853600764388, 38.6260305385802, 38.3673098976726, 52.4612649312388, 47.9340497849501, 28.314825340239, 44.4270132486533, 39.8643908403145, 45.0828047783406, 51.0644446676963]
        # Current radius of every asperity; step() replaces it with the model output.
        self.radii = [0.113233924094511, 0.101748019099675, 0.437187629043035, 0.228214030885334, 0.319377830889296, 0.171718218543952, 0.234838471715822, 0.136593051549311, 0.340397040155141, 0.319229651988948, 0.356741477180556, 0.411208149987085, 0.121935094979036, 0.140385479282672, 0.270744858499564, 0.126962322851474, 0.128314439596748, 0.273687824957094, 0.134632768824437, 0.396435696458551, 0.135436591507821, 0.19543255007994, 0.101324459418655, 0.277068274311401, 0.253579852610039, 0.0970576851004748, 0.156256421263814, 0.0784217541134068, 0.102205121642536, 0.116965375569846, 0.0756141242223765, 0.232360468597763, 0.131333468538293, 0.0697422116051773, 0.168674352472348, 0.0764176014075452, 0.124708799965718, 0.138173225181789, 0.161921102914789, 0.0905398788563845, 0.123636820035505, 0.158427509604182, 0.250721324589674, 0.0810411038035561, 0.192756133031979, 0.141660032778402, 0.107273369492107, 0.137699968031425, 0.0639050418984785, 0.0962646969512781, 0.111711488087792, 0.0803383376563764, 0.112042466809329, 0.22636813178624, 0.10493443814606, 0.084858006433817, 0.180034710078094, 0.0921310168381199, 0.109530696252363, 0.107498909933044, 0.136822486324849, 0.0868722097349304, 0.106146948207313, 0.102159021881258, 0.0925637783171733, 0.126553407431249, 0.176891320574086, 0.134200086305452, 0.0735054682923434, 0.150985454845058, 0.13415293862226, 0.105784568328975, 0.11776444668559, 0.0964170918198695, 0.0821907586326462, 0.0737872818644906, 0.111606657695424, 0.174238559223905, 0.0965627968844809, 0.0682575858932976, 0.0856785136464029, 0.130420492816041, 0.107628089987886, 0.0926007589680865, 0.0642085849932048, 0.0827123215709131, 0.06750775499916, 0.0993970269903465, 0.0914871320771519, 0.0673415351193994, 0.0755534018962921, 0.0670744117323482, 0.0992359348787822, 0.0773010432199216, 0.100612341965999, 0.0837167933604057, 0.10584585925577, 0.149390457865821, 0.146552210040749, 0.0923168385159858, 0.0751426947698626, 0.104588370344252, 
0.109559027366832, 0.0995369287337584, 0.0639582042203695, 0.080182708829321, 0.114221966778357, 0.0642309199356442, 0.118188662829656, 0.0676711148634725, 0.0775837843419188, 0.0831487644816609, 0.0830345795924003, 0.0935771830034747, 0.0902778973602378, 0.109395077239332, 0.0780857771841979, 0.0840716598079553, 0.160795186542398, 0.0925041407393677, 0.0724439823233031, 0.0942152742368031, 0.0785242111817895, 0.066045931387475, 0.0694600167644921, 0.0890285613528438, 0.0934523846756875, 0.103646597498166, 0.119388706489651, 0.0773596778025264, 0.0739658997305345, 0.110108974011096, 0.0968193395911311, 0.0737228493132348, 0.103355863385718, 0.104366341588327, 0.110597801666113, 0.102068472210233, 0.0637914130844278, 0.112124968524856, 0.121202943232005, 0.0667409962244479, 0.0938881446007618, 0.0687736889034075, 0.0686767760808669, 0.0766234057916105, 0.097195429793871, 0.0666901068255815, 0.0814082226123844, 0.115530264989523, 0.0936228355126131, 0.0720094422223202, 0.0760614136194917, 0.0642139792628998, 0.0671552879050814, 0.0890849897804774, 0.0858625999301789, 0.0642762037588933, 0.0743885145256234, 0.0770481716558152, 0.07941220179688, 0.0763113166025852, 0.0672059269361755, 0.0977674900789974, 0.0988911042070281, 0.0663716779407418, 0.064216468667195, 0.103524120166233, 0.116103169135675, 0.0695111977686663, 0.0746924626951454, 0.0718893464097982, 0.0807560885944797, 0.0673626627816155, 0.0942188937390215, 0.0640166941516658, 0.0798247497265043, 0.081864112350456, 0.0648962651650743, 0.0795761260471463, 0.080527396480742, 0.0669185265615649, 0.0786276125350952, 0.0753064331495827, 0.0808898107848158, 0.071088212398468, 0.0866476679182046, 0.063736005184869, 0.0893175824690016, 0.0855054252508395, 0.0957482198546107, 0.0699024499331349, 0.0741199169465697, 0.074593045638885, 0.0770690327469709, 0.117212785442814, 0.0792442508533426, 0.0650154754096118, 0.0811809351882745, 0.0649406114575854, 0.0767065343498129, 0.0717678062707422, 0.0848476781174467, 
0.0732460766070507, 0.0764140798367262, 0.0700524772667642, 0.0739557136986116, 0.0732219844120573, 0.0816263363809335, 0.0683857884112691, 0.0644510017827481, 0.0949390079172598, 0.0654414046289605, 0.0770789274838451, 0.0694568321274686, 0.0704832497570491, 0.0917932343559573, 0.0683774872483854, 0.0881291050993092, 0.0694665603152745, 0.0720732623721641, 0.0723855829781667, 0.0637529836015904, 0.0683888959261593, 0.0680779483618791, 0.0655056136077274, 0.0744323031091506, 0.0655142154692383, 0.0640541324860992, 0.0700365901066015, 0.065529836004009, 0.0713012593730411, 0.0710496332955136, 0.072523777341037, 0.0682404494563306, 0.0643857449554819, 0.0851088280531139, 0.0796124994424785, 0.0848433504546301, 0.0646437384708561, 0.0761730914326844, 0.0643630630536189, 0.0937775780940989, 0.0713484347100066, 0.0735171640653013, 0.0655062234185408, 0.0697368519054537, 0.0720235354644835, 0.0738301700263561, 0.0639725378069966, 0.0638936327335303, 0.0697122890648149, 0.0662510854792211, 0.0637606701609987, 0.0689257596650248, 0.0858533098317801, 0.064220805856498, 0.0723699823703535, 0.0721031617323258, 0.090129110784952, 0.0785368835431961, 0.0681795955771215, 0.0637289751922699, 0.0771533345293758, 0.0650993317945627, 0.0655402809334693, 0.0870996188119037, 0.0706124004143381, 0.0662374578724887, 0.0824117665233451, 0.0667263889865738, 0.065601956995018, 0.083861193816201, 0.0659730562082503, 0.0657478205810517, 0.0734807157747719, 0.074214282736536, 0.0651566477703054, 0.0780312038488369, 0.0697717646490748, 0.0664266392623208, 0.0665186995615388, 0.06640661087734, 0.0683926365506579, 0.0665405914541578, 0.0658530164173775, 0.0735226688985105, 0.0638984856065352, 0.072564001804519, 0.0712870262588192, 0.0667957660086738, 0.0669746405868997, 0.0642840692106619, 0.0643816038665761, 0.0652338511872316, 0.0655634740446282, 0.073853891714147, 0.0773513956798712, 0.0698774119978302, 0.0657591905834395, 0.0659934227274224, 0.0647787817293954, 0.0661112636007874, 
0.0750992201700891, 0.0660458137299479, 0.0703674103370386, 0.0694756071741951, 0.0693421085051897, 0.0643444087305131, 0.0719538611806128, 0.0671662281067296, 0.0642787622823467, 0.0680118544757353, 0.0705676359886643, 0.0766349053819236, 0.0681027938942694, 0.0655008412527026, 0.065868854876082, 0.064760727571228, 0.0670069066891407, 0.0642042776112608, 0.0684707754841449, 0.0690989295710015, 0.0656197917004959, 0.0651826002399615, 0.064767119317286, 0.0675922855902568, 0.0675502478014084, 0.0645398314721124, 0.0648003698415808, 0.0638325658069992, 0.0688023832450917, 0.0665535689966915, 0.067216996027588, 0.0641572013502763, 0.0649199425861354, 0.0665334287524746, 0.0654015939518924, 0.0648430495605883, 0.0648444013737319, 0.0661697040655581, 0.0646133511215998, 0.0644152157578926, 0.0637012834960848, 0.0639742972632775, 0.064051286978438, 0.0654892629927637, 0.0640131990895656, 0.0638549631510023, 0.065934197383157, 0.0643127628676575, 0.0681074178134324, 0.0637434619152091, 0.0637012834960848, 0.0637012834960848, 0.0637012834960848, 0.0646057256874644, 0.0663240564087551, 0.0643979209327216, 0.0642199072097038, 0.0639402272182903, 0.0642421417421663, 0.06407373766392, 0.0637012834960848, 0.0637012834960848, 0.0637205502553879, 0.064090378813399, 0.0637012834960848, 0.0637012834960848, 0.0637012834960848, 0.0637012834960848, 0.0638232548547496, 0.0637012834960848, 0.0640475445969101]
        # Stage-1 process coefficients; step() reloads them for later stages.
        self.alpha = 0.30456
        self.beta = 2.42E-4
        self.sigma = 0.001827
        self.stagenumber = 1
        # Radii of the unpolished asperities; only read (never modified) by this class.
        self.initial_asperity_radii = [6.48815534415986, 5.8300280517368, 25.0502777719566, 13.0763646666972, 18.2999308454684, 9.83922871387589, 13.4559364383205, 7.82660271179036, 19.5043039696779, 18.2914403888153, 20.4408187754813, 23.5616876941751, 6.98672102424493, 8.0439038471375, 15.5133253097505, 7.27477458811898, 7.3522491043201, 15.6819534280623, 7.71428108256875, 22.7152454810843, 7.76033906811201, 11.1980288095078, 5.80575863751349, 15.8756487425998, 14.5297929842109, 5.56127806496248, 8.95328800789351, 4.49346366045399, 5.85621942795209, 6.70196261989251, 4.33259014990863, 13.3139501095604, 7.52523550320978, 3.99613725796502, 9.66481918000773, 4.3786283388573, 7.14565068226369, 7.91714458853137, 9.27786683725619, 5.18781631528845, 7.0842276381657, 9.07768852241108, 14.3660030776076, 4.64354896241863, 11.0446736227845, 8.11693409086405, 6.1466233827181, 7.89002757449292, 3.66167508922772, 5.51584088401862, 6.40092175765188, 4.60328138386405, 6.41988640431687, 12.9705970692185, 6.01260568410148, 4.86225247725925, 10.3157527716915, 5.27898643486212, 6.27596524559347, 6.1595465541872, 7.83974902352469, 4.97766368478249, 6.08208092040257, 5.8535779720908, 5.30378310004456, 7.25134427083322, 10.1356406756718, 7.68948894169465, 4.21176693062304, 8.65126854497887, 7.68678744128197, 6.06131702863553, 6.74774834684175, 5.52457291012281, 4.70942267635152, 4.22791447870817, 6.3949151136544, 9.98364092940673, 5.53292161923264, 3.91107015176309, 4.90926646444402, 7.47292319169375, 6.1669484019098, 5.30590204284172, 3.67906771045832, 4.73930754868366, 3.86810582493332, 5.69531928720185, 5.24209268251445, 3.85858164382065, 4.32911084026141, 3.84327582407532, 5.68608892048176, 4.42924839607125, 5.7649552414194, 4.79686246485538, 6.06482891835774, 8.55987730990949, 8.39724943189899, 5.28963377329843, 4.30557786055999, 5.9927764529263, 6.27758858129625, 5.70333547361801, 3.66472122054961, 4.59436405618121, 6.5447688941248, 3.68034747336779, 6.77205537553664, 
3.87746613091172, 4.4454490915463, 4.764315155074, 4.75777251067302, 5.36184985948792, 5.17280511914835, 6.26819445401133, 4.47421262304455, 4.81719584689211, 9.21335330575911, 5.30036594505108, 4.15094517673742, 5.39841165030006, 4.49933431609973, 3.78434524916755, 3.9799678636927, 5.10121980443066, 5.3546990789789, 5.93881410398201, 6.84081630339532, 4.43260808075517, 4.2381490481749, 6.30909980275674, 5.54762117986757, 4.22422257802459, 5.92215542063699, 5.98005449640285, 6.33710898629102, 5.84838959469707, 3.65516430720696, 6.42461364441546, 6.94476968935811, 3.82417155274755, 5.37966755118589, 3.94064217737254, 3.93508919974228, 4.39042066015149, 5.56917065524172, 3.82125565693403, 4.66458438868764, 6.61973266578965, 5.36446568838822, 4.12604659885107, 4.35821924574264, 3.67937679503726, 3.84790992270372, 5.1044530793281, 4.91981436707496, 3.68294217705017, 4.26236432167357, 4.41475918708938, 4.55021501374343, 4.37253835892577, 3.85081147277224, 5.60194895932157, 5.66633057523617, 3.80301009945266, 3.67951943463851, 5.93179631349148, 6.6525593219995, 3.98290046810793, 4.27978015315871, 4.11916526629939, 4.62721796473641, 3.85979223115667, 5.3986190430302, 3.66807261690326, 4.5738534690537, 4.69070627277909, 3.71847088241291, 4.5596076582968, 4.61411420654928, 3.83434380823307, 4.50525907803312, 4.31495985497483, 4.63488007086583, 4.07326133813859, 4.96479773317103, 3.65198950722979, 5.11778033533489, 4.89934872638392, 5.48624742374004, 4.0053187040042, 4.24697403265991, 4.27408368622061, 4.41595450025325, 6.71613888113518, 4.54059164494665, 3.7253014730805, 4.65156111737394, 3.72101186681356, 4.39518381751327, 4.11220117573155, 4.86166067827683, 4.19690412724834, 4.37842655746792, 4.01391507346009, 4.23756540190471, 4.19552367607217, 4.67708202156898, 3.91841598790507, 3.69295787456919, 5.43988064101246, 3.74970665874239, 4.41652145569918, 3.97978538816791, 4.03859777219467, 5.25963194163641, 3.91794034215882, 5.04968214074131, 3.98034280057066, 
4.12970340973691, 4.14759897085193, 3.65296234824071, 3.91859404443277, 3.9007771568077, 3.75338574901556, 4.26487335008089, 3.75387862439878, 3.67021778436586, 4.0130047600188, 3.75477366055479, 4.08546865037021, 4.07105080613585, 4.15551732660229, 3.91008825639934, 3.68921874395125, 4.87662422709204, 4.56169180614558, 4.86141270887318, 3.70400143402333, 4.36461823797297, 3.68791910072185, 5.37333223536224, 4.08817173530781, 4.21243708307451, 3.7534206903754, 3.99583015420012, 4.12685412314372, 4.23037191413517, 3.66554251626482, 3.66102135903932, 3.99442273564571, 3.79610030956471, 3.65340277801387, 3.94935563257219, 4.91928205662068, 3.67976794989735, 4.14670507648445, 4.13141660386112, 5.16427984351828, 4.50006042578851, 3.90660140893061, 3.65158669787371, 4.42078488182314, 3.73010633396041, 3.75537214131432, 4.99069392664574, 4.04599793547605, 3.79531946557157, 4.7220861386364, 3.82333457717478, 3.75890609875282, 4.80513642159104, 3.78016959697388, 3.76726389093225, 4.21034864382916, 4.25238106866411, 3.73339046420048, 4.47108564250138, 3.99783060862149, 3.80615931109729, 3.81143424550526, 3.80501171090514, 3.91880837731071, 3.81268862224108, 3.77329147438923, 4.2127525027978, 3.66129942229795, 4.15782213560537, 4.08465311159205, 3.82730978955305, 3.83755906829112, 3.68339285712096, 3.688981464988, 3.73781411873796, 3.75670107616593, 4.23173113573456, 4.43213352598904, 4.00388406314852, 3.7679153771573, 3.78133657181172, 3.7117392356282, 3.78808869930706, 4.30308682300997, 3.78433850754555, 4.03196032535163, 3.98086117372184, 3.97321187507153, 3.68685023247306, 4.12286187805586, 3.84853678191007, 3.68308877708854, 3.89699006381388, 4.04343299259356, 4.39107957159471, 3.90220077322871, 3.75311229933159, 3.77419899730083, 3.71070475604326, 3.83940787363249, 3.67882090312207, 3.92328563572018, 3.95927804107207, 3.75992800398646, 3.73487750667508, 3.7110709947096, 3.87294931694117, 3.87054060677278, 3.69804769926949, 3.71297620614358, 3.65752230423845, 
3.94228632551838, 3.8134322184082, 3.85144571719305, 3.67612349510825, 3.71982756758674, 3.81227820882852, 3.74742555915815, 3.7154216980682, 3.71549915518843, 3.79143726129366, 3.702260278764, 3.69090738227839, 3.65, 3.66564333080219, 3.67005474050876, 3.752448880222, 3.6678723544287, 3.65880564267568, 3.77794303725943, 3.68503696603503, 3.90246571773247, 3.65241676809876, 3.65, 3.65, 3.65, 3.7018233513888, 3.80028144812553, 3.68991641147828, 3.67971645861461, 3.66369116190734, 3.6809904681641, 3.67134113534278, 3.65, 3.65, 3.65110396003968, 3.67229465138303, 3.65, 3.65, 3.65, 3.65, 3.65698879888588, 3.65, 3.66984030695536]
        self.time_interval = 0      # number of completed model steps
        # Initial observation; non-terminal steps replace it with the Sa from model_polish.
        self.state = 8.62
        self.run_time = 0           # accumulated Time_ over completed model steps
        self.Force_path =[]
        self.speed_path = []
        self.reward_path = []
        return self.state

    def step(self,action1,action2):
        """Apply one (force, speed) action pair for a fixed ``Time_ = 50.0`` slice.

        Parameters
        ----------
        action1, action2 : int or 0-d integer array
            Indices into the module-level ``Force_list`` and ``speed_list``.

        Returns
        -------
        (state, reward, done)
            If the current state is already <= ``threshold``, the model is not
            advanced and ``done`` is True. The reward for the requested action
            is still computed and logged in that case. Otherwise the polishing
            model runs once and ``done`` is False.
        """
        # Two discrete actions: a force index and a speed index.
        Force_ = Force_list[int(action1)]
        speed_ = speed_list[int(action2)]
        Time_ = 50.0
        reward = self.reward(Force_,speed_,Time_)
        self.Force_path.append(Force_)
        self.speed_path.append(speed_)
        self.reward_path.append(reward)
        if self.state <= self.threshold:
            return self.state,reward,True

        # Advance the polishing model; it returns new heights, radii and the Sa roughness.
        heights_new, radii_new, sa_model = self.transition_function(Force_,speed_,Time_)
        self.run_time += Time_
        self.time_interval +=1
        self.heights = heights_new
        self.radii = radii_new
        # Coefficients are tabulated per (1-based) process stage.
        self.stagenumber = self.get_stagenumber(self.run_time)
        self.alpha = alpha_list[self.stagenumber-1]
        self.beta = beta_list[self.stagenumber-1]
        self.sigma = sigma_list[self.stagenumber-1]
        self.state = sa_model
        return self.state,reward,False

    def reward(self,Force_,speed_,Time_):
        """Return ``-(unit_power * Time_ / 3600)`` for running at (Force_, speed_).

        ``unit_power`` is read from the module-level ``power`` table, whose
        row is the speed index and whose column is the force index.
        """
        col = Force_list.index(Force_)
        row = speed_list.index(speed_)
        unit_power = power[row,col]
        return -(unit_power*Time_/3600)

    def transition_function(self,Force_,speed_,Time_):
        """Run ``model_polish`` once from the current hidden state.

        Returns ``(heights, radii, Sa)`` and leaves ``self`` untouched;
        ``step`` commits the results.
        """
        return self.model_polish(Force_, speed_, Time_, self.heights, self.radii,
                                 self.alpha, self.beta, self.sigma,
                                 float(self.stagenumber), self.initial_asperity_radii)

    def get_stagenumber(self,value):
        """Return the stage number that applies after ``value`` elapsed time.

        Assumes the module-level ``segment_time`` is an ascending NumPy array
        of stage boundaries and ``stage_nb`` lists the matching stage numbers
        (TODO confirm). Raises ValueError, from ``argmax`` of an empty array,
        when no boundary is smaller than ``value``.
        """
        idx = segment_time[segment_time < value].argmax() # last boundary below value (if ascending)
        return stage_nb[idx+1] # next stage number

    def solve_for_d(self,Force,radii,heights,vectorlength):
        """Find the separation ``d`` at which asperity contact balances ``Force``.

        Scans ``d = 0, dd, 2*dd, ...`` up to ``max(heights)``. At each ``d``
        the contact force is ``(2/3) * E * sum(sqrt(r_j) * (h_j - d)**1.5)``,
        summed over asperities with ``h_j > d``. The scan stops at the first
        grid index ``k`` whose contact force is below ``Force``.

        Parameters
        ----------
        Force : float
            Applied load.
        radii, heights : array-like
            Per-asperity radii and heights.
        vectorlength : int
            Unused. Kept for interface compatibility; the grid length is
            derived from ``heights`` instead.

        Returns
        -------
        float
            ``k*dd - dd/2``: the midpoint between the last grid point that
            still carries at least ``Force`` and the first one that does not.
            This is ``-dd/2`` if already ``d = 0`` carries less than ``Force``.

        Raises
        ------
        ValueError
            If the contact force never drops below ``Force`` on the grid
            (for example when ``Force <= 0``).
        """
        heights = np.asarray(heights, dtype=float)
        sqrt_radii = np.sqrt(np.asarray(radii, dtype=float))
        depths = np.arange(0, np.max(heights)+dd, dd)

        for k_crit, d in enumerate(depths):
            in_contact = heights > d
            contact_force = np.sum(sqrt_radii[in_contact]*np.power(heights[in_contact] - d, 3/2))*E*(2/3)
            if Force - contact_force > 0:
                # The balance point lies between depths[k_crit-1] and depths[k_crit].
                # Use 0-based k_crit here; the MATLAB (k_crit-1) form is off by one grid step.
                return k_crit*dd - dd/2

        raise ValueError("solve_for_d: contact force never drops below the applied Force on the search grid")

    def surface_roughness(self,surface_radii, initial_asperity_radii, h, N, packing_density):
        """Build a sampled height profile from the asperities and return its (std, Sa).

        Each asperity ``i`` owns ``round(r0_i**2 / 20) + 2`` consecutive
        samples, where ``r0_i`` is its initial radius. Let
        ``ratio = (r_i / r0_i)**2``, with ``r_i`` the current radius.

        - If ``ratio >= 0.8``, every sample of the asperity is ``h_i``.
        - Otherwise the samples are split into ``h_i``,
          ``h_i - (2/5)*sqrt(r0_i**2 - r_i**2)`` and
          ``h_i - sqrt(r0_i**2 - r_i**2)``.

        ``k = round((1 - packing_density) * sum(r0**2) / 40)`` extra filler
        samples are drawn from the global NumPy RNG. They lie in
        ``[min(h)/2, min(h)/2 + (quantile(h, 0.2) - min(h)))``. A share of them
        equal to the fraction of asperities with ``r_i > r0_i`` is then raised
        to the profile maximum.

        Returns
        -------
        stdev_model : float
            Population standard deviation of the profile.
        Sa_model : float
            Mean absolute deviation of the profile from its mean.
        """
        surface_radii = np.asarray(surface_radii, dtype=float)
        init_radii = np.asarray(initial_asperity_radii, dtype=float)

        # Samples per asperity; asperity i owns one contiguous slice of the profile.
        counts = [round(float(r0)**2/20) + 2 for r0 in init_radii[:N]]
        # Placeholder values (negative asperity labels); every asperity slice is overwritten below.
        labels = np.repeat(-np.arange(1, N+1, dtype=float), counts)

        k = int(round((1 - packing_density)*np.sum(np.power(initial_asperity_radii,2))/40))
        h_min = np.min(h)
        misc_heights = (h_min/2) + ((np.quantile(h,0.2) - h_min)*np.random.random_sample(k))
        profile = np.concatenate([labels, misc_heights])

        start = 0
        for i, m in enumerate(counts):
            stop = start + m
            asperity_surface_ratio = np.power(surface_radii[i],2) / np.power(init_radii[i],2)
            if asperity_surface_ratio < 0.8:
                q1 = int(round(m*asperity_surface_ratio)+1)
                q2 = int(round((m - q1)/2))
                if q2 == 0:
                    q2 = 1
                if q1 >= m:
                    q1 = q1 - 1
                depth = np.sqrt(np.power(init_radii[i],2) - np.power(surface_radii[i],2))

                profile[start:start+q1] = h[i]
                if q2 == 1:
                    profile[start+q1:stop] = h[i] - (2/5)*depth
                else:
                    # NOTE(review): kept exactly as originally ported. When q2 <= q1
                    # the middle slice is empty, and the last assignment overwrites
                    # part of the h[i] plateau set above. q2 may have been meant as
                    # an offset from q1; confirm against the reference model.
                    profile[start+q1:start+q2] = h[i] - (2/5)*depth
                    profile[start+q2:stop] = h[i] - depth
            elif asperity_surface_ratio >= 0.8:
                profile[start:stop] = h[i]
            start = stop

        # Raise the last filler samples to the profile maximum, in proportion
        # to the share of asperities grown past their initial radius.
        void_space_ratio = int(np.count_nonzero(surface_radii > init_radii)) / N
        n_raised = int(round(k*void_space_ratio))
        if n_raised > 0:
            profile[len(profile) - n_raised:] = np.max(profile)

        mean_model = np.mean(profile)
        stdev_model = np.std(profile)
        Sa_model = (1/len(profile))*np.sum(np.absolute(profile - mean_model))
        return stdev_model,Sa_model

    def model_polish(self,Force, speed, Time, heights_m0, radii_m0, alpha, beta, sigma, stagenumber, initial_asperity_radii):
        """Integrate the asperity polishing model over one action of length ``Time``.

        Uses explicit Euler steps of ``dt = 0.2`` on per-asperity temperature,
        sigma, height and radius. At every step, ``solve_for_d`` finds the
        separation ``d_crit`` that balances ``Force``; only asperities taller
        than ``d_crit`` carry load.

        Parameters
        ----------
        Force, speed : float
            Applied load and speed for this action.
        Time : float
            Horizon; the model runs ``len(arange(0, Time+dt, dt)) - 1`` steps.
        heights_m0, radii_m0 : sequence of float
            Per-asperity heights and radii at the start of the horizon.
        alpha, beta, sigma : float
            Stage-dependent coefficients of the temperature and sigma dynamics.
        stagenumber : float
            Not used by the model; kept for interface compatibility.
        initial_asperity_radii : sequence of float
            Unpolished radii, forwarded to ``surface_roughness``.

        Returns
        -------
        heights_m1, radii_m1 : ndarray
            Heights and radii at the end of the horizon.
        Sa_model_1 : float
            Sa roughness of the resulting surface.
        """
        N = len(heights_m0)          # number of asperities (374 for the data in reset)
        packing_density = 0.7024
        ambient_temp = 20
        dt = 0.2
        a_0 = 2.5
        kappa_1 = 15.0               # force term gain in the temperature dynamics
        kappa_2 = 0.674              # speed term gain in the temperature dynamics

        # Legacy depth-grid length (0..55.9311 in 0.01 steps), still accepted by solve_for_d.
        vectorlength = len(np.arange(0, 55.9311+0.01, 0.01))

        horizon = np.arange(0, Time+dt, dt)
        n_cols = len(horizon)
        temp = np.zeros((N, n_cols))
        temp[:, 0] = ambient_temp
        sig = np.zeros((N, n_cols))
        sig[:, 0] = sigma
        h = np.zeros((N, n_cols))
        h[:, 0] = heights_m0
        r = np.zeros((N, n_cols))
        r[:, 0] = radii_m0
        force = np.zeros((N, n_cols))   # force carried by each asperity

        # Loop-invariant normaliser of the material-removal coefficient.
        max_height_sq = np.power(max(heights_m0), 2)

        for k in range(1, n_cols):
            d_crit = self.solve_for_d(Force, r[:, k-1], h[:, k-1], vectorlength)
            loaded = h[:, k-1] > d_crit   # load-bearing asperities
            force[loaded, k] = (2/3)*E*np.sqrt(np.absolute(r[loaded, k-1]))*np.power(np.absolute(h[loaded, k-1] - d_crit), 3/2)

            # NOTE(review): the temperature and height updates read force[:, k-1]
            # (the previous step's load, zero at k=1). force[:, k] is therefore
            # used one step later and the final column is never used. Confirm
            # this lag is intended.
            temp[:, k] = temp[:, k-1] + dt*(-alpha*(temp[:, k-1] - ambient_temp) + kappa_1*force[:, k-1] + kappa_2*speed)
            sig[:, k] = sig[:, k-1] + (beta/ambient_temp)*(temp[:, k] - temp[:, k-1])

            diff_heights = dt*sig[:, k-1]*temp[:, k-1]*force[:, k-1]
            h[:, k] = h[:, k-1] - diff_heights

            rho = h[:, k-1]/(4*r[:, k-1])
            mr_coeff_eta = np.exp(-(h[:, k-1]/max_height_sq))*np.exp(-np.power((a_0/r[:, k-1]), 0.0833))
            r[:, k] = r[:, k-1] + (mr_coeff_eta*rho*diff_heights)

        heights_m1 = h[:, -1]
        radii_m1 = r[:, -1]

        _, Sa_model_1 = self.surface_roughness(radii_m1, initial_asperity_radii, heights_m1, N, packing_density)
        return heights_m1, radii_m1, Sa_model_1

In [None]:
class SharedAdam(T.optim.Adam):
    """Adam whose per-parameter state is created eagerly in shared memory.

    The worker processes (``Agent``) step this optimizer on the global
    network. Adam normally creates its state lazily on the first ``step()``,
    inside whichever process runs first, and that memory would not be shared.
    So ``step``, ``exp_avg`` and ``exp_avg_sq`` are all allocated here and
    moved to shared memory.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
            weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                weight_decay=weight_decay)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # PyTorch >= 1.12 Adam requires 'step' to be a singleton tensor;
                # a plain int raises "API has changed ...". A tensor also lets
                # the counter itself live in shared memory.
                state['step'] = T.tensor(0.0)
                state['exp_avg'] = T.zeros_like(p.data)
                state['exp_avg_sq'] = T.zeros_like(p.data)

                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network with two categorical policy heads and a value head.

    Each head has its own 128-unit ReLU hidden layer, fed directly by the
    state; there is no shared trunk. ``calc_R`` and ``calc_loss`` reshape
    stored states to ``(-1, 1)``, so the code assumes a 1-dimensional state.

    The module doubles as a rollout buffer: ``remember`` stores transitions
    and ``calc_loss`` turns them into an advantage actor-critic loss. In this
    notebook, action 1 indexes ``Force_list`` and action 2 indexes
    ``speed_list`` inside ``Environment.step``.
    """
    def __init__(self, input_dims, n_actions_1,n_actions_2, gamma=0.99):
        super(ActorCritic, self).__init__()

        self.gamma = gamma

        # Layer creation order is unchanged so state_dict keys and init RNG order match.
        self.pi_a1 = nn.Linear(*input_dims, 128)
        self.pi_a2 = nn.Linear(*input_dims, 128)
        self.v1 = nn.Linear(*input_dims, 128)

        self.pi_a1_out = nn.Linear(128, n_actions_1)
        self.pi_a2_out = nn.Linear(128, n_actions_2)
        self.v = nn.Linear(128, 1)

        self.clear_memory()

    def remember(self, state, action1,action2, reward):
        """Append one transition (already-processed state, both actions, reward)."""
        self.states.append(state)
        self.actions_1.append(int(action1))
        self.actions_2.append(int(action2))
        self.rewards.append(reward)

    def clear_memory(self):
        """Empty the rollout buffer."""
        self.states = []
        self.actions_1 = []
        self.actions_2 = []
        self.rewards = []

    def forward(self, state):
        """Return ``(logits_1, logits_2, value)`` for a batch of states."""
        pi_a1 = F.relu(self.pi_a1(state))
        pi_a2 = F.relu(self.pi_a2(state))
        v1 = F.relu(self.v1(state))

        action1_dist = self.pi_a1_out(pi_a1)
        action2_dist = self.pi_a2_out(pi_a2)
        v = self.v(v1)

        return action1_dist,action2_dist, v

    def calc_R(self, done):
        """Discounted returns for the stored rewards, as a detached float tensor.

        The recursion is bootstrapped from the critic's value of the last
        stored state, or from 0 when ``done`` is true.
        """
        if done:
            bootstrap = 0.0
        else:
            last_state = T.tensor(self.states[-1:], dtype=T.float).reshape(-1, 1)
            with T.no_grad():
                _, _, v = self.forward(last_state)
            bootstrap = v[-1].item()

        R = bootstrap
        batch_return = []
        for reward in reversed(self.rewards):
            R = reward + self.gamma*R
            batch_return.append(R)
        batch_return.reverse()
        return T.tensor(batch_return, dtype=T.float)

    def calc_loss(self, done):
        """Mean of critic loss plus both policy-gradient losses over the rollout.

        The advantage ``returns - values`` is detached in the actor terms.
        The value baseline must not receive gradient from the policy loss;
        only the squared-error critic term trains the value head.
        """
        states = T.tensor(self.states, dtype=T.float).reshape(-1, 1)
        actions_1 = T.tensor(self.actions_1, dtype=T.long)
        actions_2 = T.tensor(self.actions_2, dtype=T.long)

        returns = self.calc_R(done)

        logits_1, logits_2, values = self.forward(states)
        values = values.squeeze(-1)
        critic_loss = (returns - values)**2
        advantage = (returns - values).detach()

        # logits= normalises with log-softmax, which is more stable than softmax -> probs -> log.
        actor_loss_1 = -Categorical(logits=logits_1).log_prob(actions_1)*advantage
        actor_loss_2 = -Categorical(logits=logits_2).log_prob(actions_2)*advantage

        return (critic_loss + actor_loss_1 + actor_loss_2).mean()

    def choose_action(self, observation):
        """Sample one index from each policy head; each is returned as a 0-d numpy array."""
        state = T.tensor([observation], dtype=T.float)
        with T.no_grad():
            logits_1, logits_2, _ = self.forward(state)
        action_1 = Categorical(logits=logits_1).sample().numpy()
        action_2 = Categorical(logits=logits_2).sample().numpy()
        return action_1,action_2

In [None]:
class Agent(mp.Process):
    """Asynchronous actor-critic worker process.

    Each worker owns a local ActorCritic and its own environment. Every
    ``T_MAX`` steps (or at episode end) it computes the loss on the local
    network and copies the local gradients onto the shared global network. It
    then steps the shared optimizer and reloads its local weights from the
    global network.
    """

    def __init__(self, global_actor_critic, optimizer, input_dims, n_actions_1, n_actions_2,
                 gamma, lr, name, global_ep_idx, Environment, threshold, scaler, N_GAMES, T_MAX):
        """
        Args:
            global_actor_critic: shared network whose parameters ``optimizer`` updates.
            optimizer: optimizer over ``global_actor_critic.parameters()``.
            input_dims, n_actions_1, n_actions_2, gamma: forwarded to the local ActorCritic.
            lr: accepted for call-site compatibility; unused here (the learning
                rate lives in ``optimizer``).
            name: integer worker id, shown as ``'wNN'``.
            global_ep_idx: shared integer ``mp.Value`` counting finished episodes
                across all workers.
            Environment: environment class, instantiated as ``Environment(threshold)``.
            threshold: forwarded to the environment constructor.
            scaler: fitted object exposing ``transform``; used to scale observations.
            N_GAMES: total episodes (across all workers) after which workers stop.
            T_MAX: number of steps between gradient updates.
        """
        super(Agent, self).__init__()
        self.local_actor_critic = ActorCritic(input_dims, n_actions_1, n_actions_2, gamma)
        self.global_actor_critic = global_actor_critic
        self.name = 'w%02i' % name
        self.episode_idx = global_ep_idx
        self.env = Environment(threshold)
        self.optimizer = optimizer
        # Store these on the instance rather than reading same-named module
        # globals, so the worker does not depend on hidden __main__ state.
        self.scaler = scaler
        self.n_games = N_GAMES
        self.t_max = T_MAX

    def process_state(self, state):
        """Scale an observation with ``self.scaler`` and return its first element as a scalar."""
        state = np.array(state).reshape(-1, 1)
        scaled = self.scaler.transform(state)
        return scaled[0][0]

    def _update_global(self, done):
        """Push local gradients to the global network, step it, and re-sync the local copy."""
        loss = self.local_actor_critic.calc_loss(done)
        # Clear the local grads explicitly. After the first update, each global
        # param's grad *is* the local grad tensor, and optimizer.zero_grad()
        # (set_to_none=True by default in recent PyTorch) only drops the global
        # reference. Without this, local gradients would accumulate across updates.
        self.local_actor_critic.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        for local_param, global_param in zip(
                self.local_actor_critic.parameters(),
                self.global_actor_critic.parameters()):
            global_param._grad = local_param.grad
        self.optimizer.step()
        self.local_actor_critic.load_state_dict(
                self.global_actor_critic.state_dict())
        self.local_actor_critic.clear_memory()

    def run(self):
        """Play episodes until the shared episode counter reaches ``n_games``."""
        t_step = 1
        while self.episode_idx.value < self.n_games:
            done = False
            observation = self.env.reset()
            score = 0
            steps_taken = 0
            self.local_actor_critic.clear_memory()
            while not done:
                # scale the raw observation before it reaches the network
                processed_obs = self.process_state(observation)
                action_1, action_2 = self.local_actor_critic.choose_action(processed_obs)
                observation_, reward, done = self.env.step(action_1, action_2)
                score += reward
                # memory holds the scaled observation, matching what the network saw
                self.local_actor_critic.remember(processed_obs, action_1, action_2, reward)
                if t_step % self.t_max == 0 or done:
                    self._update_global(done)
                t_step += 1
                steps_taken += 1
                observation = observation_
            with self.episode_idx.get_lock():
                self.episode_idx.value += 1
                # read under the lock so another worker can't change it first
                episode = self.episode_idx.value
            print(self.name, 'episode ', episode, 'reward %.1f' % score,
                  'Force:', self.env.Force_path, 'speed:', self.env.speed_path, 'reward:', self.env.reward_path,
                  'steps taken:', steps_taken)

In [None]:
if __name__ == '__main__':
    # --- hyper-parameters ---
    lr = 1e-3
    n_actions_1 = 11  # presumably one action per entry of Force_list (11 entries) — confirm
    n_actions_2 = 5   # presumably one action per entry of speed_list (5 entries) — confirm
    input_dims = [1]
    E = 0.2
    dd = 0.01
    threshold = 6

    # --- problem data ---
    # NOTE(review): these names are not used in this cell. They are module-level
    # in the notebook and are presumably read by Environment (earlier cell), so
    # keep them.
    Force_list = [22.24,26.69,31.14,35.58,40.04,44.48,48.93,53.38,57.83,62.28,66.72]
    speed_list = [1.73,1.88,2.04,2.2,2.36]

    # 5 rows x 11 columns
    power = np.array([[135,135,136,139,142,147,153,161,167,172,178],
            [143,143,145,147,150,155,161,168,176,181,188],
            [154,154,155,159,163,168,175,181,187,193,197],
            [162,164,165,168,172,177,182,187,192,198,202],
            [166,167,170,175,179,184,193,197,200,205,209]])

    alpha_list = [0.30456,0.108642,1.372173,2.013025,1.343399,1.650823]
    beta_list = [0.000242,0.000029,0.000233,0.000038,0.0019,0.001363]
    sigma_list = [0.001827,0.000103,0.003548,0.002703,0.004522,0.007872]

    segment_time = np.array([0,300,600,1200,1800,2400,3300])
    stage_nb = np.array([0,1,2,3,4,5,6])

    # --- observation scaler ---
    # Fit a StandardScaler on states drawn by Environment.sample_state_space.
    # The environment's own sampler is reused instead of duplicating its logic.
    N_SCALER_SAMPLES = 10000
    _state_sampler = Environment(threshold)
    sample_state_space = _state_sampler.sample_state_space
    observation_examples = np.array([sample_state_space() for _ in range(N_SCALER_SAMPLES)])
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(observation_examples)

    # --- training setup ---
    N_GAMES = 3000
    T_MAX = 5
    N_WORKERS = 1  # set to mp.cpu_count() for full parallel training
    global_actor_critic = ActorCritic(input_dims, n_actions_1, n_actions_2)
    global_actor_critic.share_memory()
    optim = SharedAdam(global_actor_critic.parameters(), lr=lr,
                       betas=(0.92, 0.999))
    global_ep = mp.Value('i', 0)

    # NOTE(review): classes defined in a notebook's __main__ generally cannot be
    # pickled into child processes under the 'spawn' start method, so this
    # presumably relies on 'fork' (the Linux default) — confirm on other platforms.
    workers = [Agent(global_actor_critic,
                     optim,
                     input_dims,
                     n_actions_1,
                     n_actions_2,
                     gamma=0.99,
                     lr=lr,
                     name=i,
                     global_ep_idx=global_ep,
                     Environment=Environment, threshold=threshold, scaler=scaler,
                     N_GAMES=N_GAMES, T_MAX=T_MAX) for i in range(N_WORKERS)]
    # Plain loops: these calls are for their side effects, and a list
    # comprehension would also leave a stray [None, ...] cell output.
    for w in workers:
        w.start()
    for w in workers:
        w.join()