In [3]:
import sys
assert sys.version_info >= (3, 5) # make sure we have Python 3.5+

from pyspark.sql import SparkSession, functions, types
# Create (or reuse) the shared SparkSession and keep Spark's console logging at WARN.
spark = SparkSession.builder.appName('colour prediction').getOrCreate()
spark.sparkContext.setLogLevel('WARN')
# Compare (major, minor) numerically: a plain string comparison is lexicographic,
# so e.g. '2.10.0' >= '2.3' and '10.0.0' >= '2.3' would both be False.
assert tuple(int(part) for part in spark.version.split('.')[:2]) >= (2, 3) # make sure we have Spark 2.3+

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, SQLTransformer
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
import pandas as pd


def main(inputs, holdout_year=2017):
    """Fit a gradient-boosted-tree regressor for ``gdp`` and show the held-out rows.

    The CSV at ``inputs`` is read with a fixed nine-column schema. Rows whose
    ``year`` differs from ``holdout_year`` are used for training; rows from
    ``holdout_year`` are run through the fitted pipeline, and their
    ``country``, ``gdp`` and ``year`` columns are printed.

    Args:
        inputs: Path of the CSV file passed to ``spark.read.csv``.
        holdout_year: Year whose rows are excluded from training and used as
            the validation set. Defaults to 2017.

    Returns:
        None. The selected validation columns are printed with ``show()``.
    """
    gdp_schema = types.StructType([
        types.StructField('consumption', types.FloatType(), False),
        types.StructField('exports', types.FloatType(), False),
        types.StructField('imports', types.FloatType(), False),
        types.StructField('investment', types.FloatType(), False),
        types.StructField('others', types.FloatType(), False),
        types.StructField('expenditure', types.FloatType(), False),
        types.StructField('year', types.IntegerType(), False),
        types.StructField('gdp', types.FloatType(), False),
        types.StructField('country', types.StringType(), False),
    ])

    data = spark.read.csv(inputs, schema=gdp_schema)

    # 'gdp' is the label, so it must NOT be part of the feature vector:
    # feeding the target in as a feature leaks the answer to the model.
    feature_assembler = VectorAssembler(
        inputCols=["consumption", "exports", "imports", "investment",
                   "others", "expenditure", "year"],
        outputCol="features")

    # NOTE(review): 'w_index' is produced by this stage but is not one of the
    # assembler's input columns, so the country does not influence the model.
    country_indexer = StringIndexer(inputCol="country", outputCol="w_index")

    # Hold out a single year for validation; 'year' is an integer column, so
    # compare against a number rather than relying on a string cast.
    train = data.filter(data['year'] != holdout_year)
    validation = data.filter(data['year'] == holdout_year)

    regressor = GBTRegressor(featuresCol='features', labelCol='gdp', maxIter=100)
    pipeline = Pipeline(stages=[feature_assembler, country_indexer, regressor])

    model = pipeline.fit(train)
    predictions = model.transform(validation)

    # Only the identifying columns and the actual gdp are displayed here.
    check = predictions.select(
        predictions['country'], predictions['gdp'], predictions['year'])
    check.show()
   
 
    

    
if __name__ == '__main__':
    # Script entry point: run the pipeline on the hard-coded local CSV file.
    inputs = "/Users/mananparasher/Documents/big_data_project/gdp/check7.csv"
    main(inputs=inputs)

+-------------+-------------+----+
|      country|          gdp|year|
+-------------+-------------+----+
|       Canada|1.69137828E10|2017|
|        India|2.81999992E12|2017|
|United States|1.94000003E13|2017|
|        China|1.22000004E13|2017|
|        Japan| 4.8700001E12|2017|
+-------------+-------------+----+

